mirror of https://github.com/kortix-ai/suna.git
security wip
This commit is contained in:
parent
2fea768be9
commit
e01aa2e332
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Query
|
||||
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
from utils.config import config, EnvMode
|
||||
from utils.pagination import PaginationParams
|
||||
|
@ -21,7 +21,7 @@ router = APIRouter()
|
|||
async def update_agent(
|
||||
agent_id: str,
|
||||
agent_data: AgentUpdateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Updating agent {agent_id} for user: {user_id}")
|
||||
|
||||
|
@ -436,7 +436,7 @@ async def update_agent(
|
|||
raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}")
|
||||
|
||||
@router.delete("/agents/{agent_id}")
|
||||
async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
async def delete_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
|
||||
logger.debug(f"Deleting agent: {agent_id}")
|
||||
client = await utils.db.client
|
||||
|
||||
|
@ -502,7 +502,7 @@ async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id
|
|||
|
||||
@router.get("/agents", response_model=AgentsResponse)
|
||||
async def get_agents(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
|
||||
limit: Optional[int] = Query(20, ge=1, le=100, description="Number of items per page"),
|
||||
search: Optional[str] = Query(None, description="Search in name and description"),
|
||||
|
@ -570,7 +570,7 @@ async def get_agents(
|
|||
raise HTTPException(status_code=500, detail=f"Failed to fetch agents: {str(e)}")
|
||||
|
||||
@router.get("/agents/{agent_id}", response_model=AgentResponse)
|
||||
async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
async def get_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
|
||||
|
||||
logger.debug(f"Fetching agent {agent_id} for user: {user_id}")
|
||||
|
||||
|
@ -697,7 +697,7 @@ async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_fr
|
|||
@router.post("/agents", response_model=AgentResponse)
|
||||
async def create_agent(
|
||||
agent_data: AgentCreateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Creating new agent for user: {user_id}")
|
||||
client = await utils.db.client
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional, Dict, List, Any
|
|||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from uuid import uuid4
|
||||
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
from templates.template_service import MCPRequirementValue, ConfigType, ProfileId, QualifiedName
|
||||
|
||||
|
@ -343,7 +343,7 @@ class JsonImportService:
|
|||
raise # Re-raise the exception to ensure import fails if version creation fails
|
||||
|
||||
@router.get("/agents/{agent_id}/export")
|
||||
async def export_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
async def export_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
|
||||
"""Export an agent configuration as JSON"""
|
||||
logger.debug(f"Exporting agent {agent_id} for user: {user_id}")
|
||||
|
||||
|
@ -419,7 +419,7 @@ async def export_agent(agent_id: str, user_id: str = Depends(get_current_user_id
|
|||
@router.post("/agents/json/analyze", response_model=JsonAnalysisResponse)
|
||||
async def analyze_json_for_import(
|
||||
request: JsonAnalysisRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Analyze imported JSON to determine required credentials and configurations"""
|
||||
logger.debug(f"Analyzing JSON for import - user: {user_id}")
|
||||
|
@ -439,7 +439,7 @@ async def analyze_json_for_import(
|
|||
@router.post("/agents/json/import", response_model=JsonImportResponse)
|
||||
async def import_agent_from_json(
|
||||
request: JsonImportRequestModel,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Importing agent from JSON - user: {user_id}")
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Optional, List
|
|||
from fastapi import APIRouter, HTTPException, Depends, Request, Body, File, UploadFile, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt, get_user_id_from_stream_auth, verify_and_authorize_thread_access
|
||||
from utils.logger import logger, structlog
|
||||
from services.billing import check_billing_status, can_use_model
|
||||
from utils.config import config
|
||||
|
@ -32,7 +32,7 @@ router = APIRouter()
|
|||
async def start_agent(
|
||||
thread_id: str,
|
||||
body: AgentStartRequest = Body(...),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Start an agent for a specific thread in the background"""
|
||||
structlog.contextvars.bind_contextvars(
|
||||
|
@ -67,7 +67,7 @@ async def start_agent(
|
|||
thread_metadata = thread_data.get('metadata', {})
|
||||
|
||||
if account_id != user_id:
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
|
||||
structlog.contextvars.bind_contextvars(
|
||||
project_id=project_id,
|
||||
|
@ -241,7 +241,7 @@ async def start_agent(
|
|||
return {"agent_run_id": agent_run_id, "status": "running"}
|
||||
|
||||
@router.post("/agent-run/{agent_run_id}/stop")
|
||||
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
async def stop_agent(agent_run_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
|
||||
"""Stop a running agent."""
|
||||
structlog.contextvars.bind_contextvars(
|
||||
agent_run_id=agent_run_id,
|
||||
|
@ -253,20 +253,20 @@ async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_
|
|||
return {"status": "stopped"}
|
||||
|
||||
@router.get("/thread/{thread_id}/agent-runs")
|
||||
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
async def get_agent_runs(thread_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
|
||||
"""Get all agent runs for a thread."""
|
||||
structlog.contextvars.bind_contextvars(
|
||||
thread_id=thread_id,
|
||||
)
|
||||
logger.debug(f"Fetching agent runs for thread: {thread_id}")
|
||||
client = await utils.db.client
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
agent_runs = await client.table('agent_runs').select('id, thread_id, status, started_at, completed_at, error, created_at, updated_at').eq("thread_id", thread_id).order('created_at', desc=True).execute()
|
||||
logger.debug(f"Found {len(agent_runs.data)} agent runs for thread: {thread_id}")
|
||||
return {"agent_runs": agent_runs.data}
|
||||
|
||||
@router.get("/agent-run/{agent_run_id}")
|
||||
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
async def get_agent_run(agent_run_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
|
||||
"""Get agent run status and responses."""
|
||||
structlog.contextvars.bind_contextvars(
|
||||
agent_run_id=agent_run_id,
|
||||
|
@ -285,7 +285,7 @@ async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_us
|
|||
}
|
||||
|
||||
@router.get("/thread/{thread_id}/agent", response_model=ThreadAgentResponse)
|
||||
async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
async def get_thread_agent(thread_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
|
||||
"""Get the agent details for a specific thread. Since threads are fully agent-agnostic,
|
||||
this returns the most recently used agent from agent_runs only."""
|
||||
structlog.contextvars.bind_contextvars(
|
||||
|
@ -296,7 +296,7 @@ async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_us
|
|||
|
||||
try:
|
||||
# Verify thread access and get thread data
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
|
||||
|
||||
if not thread_result.data:
|
||||
|
@ -621,7 +621,7 @@ async def initiate_agent_with_files(
|
|||
enable_context_manager: Optional[bool] = Form(False),
|
||||
agent_id: Optional[str] = Form(None), # Add agent_id parameter
|
||||
files: List[UploadFile] = File(default=[]),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""
|
||||
Initiate a new agent session with optional file attachments.
|
||||
|
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
|||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request, Body
|
||||
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
from sandbox.sandbox import get_or_start_sandbox
|
||||
from services.supabase import DBConnection
|
||||
|
@ -19,7 +19,7 @@ router = APIRouter()
|
|||
async def get_custom_mcp_tools_for_agent(
|
||||
agent_id: str,
|
||||
request: Request,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Getting custom MCP tools for agent {agent_id}, user {user_id}")
|
||||
try:
|
||||
|
@ -98,7 +98,7 @@ async def get_custom_mcp_tools_for_agent(
|
|||
async def update_custom_mcp_tools_for_agent(
|
||||
agent_id: str,
|
||||
request: dict,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Updating custom MCP tools for agent {agent_id}, user {user_id}")
|
||||
|
||||
|
@ -209,7 +209,7 @@ async def update_custom_mcp_tools_for_agent(
|
|||
async def update_agent_custom_mcps(
|
||||
agent_id: str,
|
||||
request: dict,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Updating agent {agent_id} custom MCPs for user {user_id}")
|
||||
|
||||
|
@ -316,7 +316,7 @@ async def update_agent_custom_mcps(
|
|||
@router.get("/agents/{agent_id}/tools")
|
||||
async def get_agent_tools(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
logger.debug(f"Fetching enabled tools for agent: {agent_id} by user: {user_id}")
|
||||
|
|
|
@ -5,7 +5,7 @@ from datetime import datetime, timezone
|
|||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Form, Query
|
||||
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, verify_thread_access
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt, verify_and_authorize_thread_access, require_thread_access, AuthorizedThreadAccess
|
||||
from utils.logger import logger
|
||||
from sandbox.sandbox import create_sandbox, delete_sandbox
|
||||
|
||||
|
@ -16,7 +16,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/threads")
|
||||
async def get_user_threads(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
|
||||
limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)")
|
||||
):
|
||||
|
@ -121,14 +121,15 @@ async def get_user_threads(
|
|||
@router.get("/threads/{thread_id}")
|
||||
async def get_thread(
|
||||
thread_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
auth: AuthorizedThreadAccess = Depends(require_thread_access)
|
||||
):
|
||||
"""Get a specific thread by ID with complete related data."""
|
||||
logger.debug(f"Fetching thread: {thread_id}")
|
||||
client = await utils.db.client
|
||||
user_id = auth.user_id # Already authenticated and authorized!
|
||||
|
||||
try:
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
# No need for manual authorization - it's already done in the dependency!
|
||||
|
||||
# Get the thread data
|
||||
thread_result = await client.table('threads').select('*').eq('thread_id', thread_id).execute()
|
||||
|
@ -200,7 +201,7 @@ async def get_thread(
|
|||
@router.post("/threads", response_model=CreateThreadResponse)
|
||||
async def create_thread(
|
||||
name: Optional[str] = Form(None),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""
|
||||
Create a new thread without starting an agent run.
|
||||
|
@ -303,13 +304,13 @@ async def create_thread(
|
|||
@router.get("/threads/{thread_id}/messages")
|
||||
async def get_thread_messages(
|
||||
thread_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'")
|
||||
):
|
||||
"""Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries."""
|
||||
logger.debug(f"Fetching all messages for thread: {thread_id}, order={order}")
|
||||
client = await utils.db.client
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
try:
|
||||
batch_size = 1000
|
||||
offset = 0
|
||||
|
@ -333,7 +334,7 @@ async def get_thread_messages(
|
|||
@router.get("/agent-runs/{agent_run_id}")
|
||||
async def get_agent_run(
|
||||
agent_run_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
):
|
||||
"""
|
||||
[DEPRECATED] Get an agent run by ID.
|
||||
|
@ -355,12 +356,12 @@ async def get_agent_run(
|
|||
async def add_message_to_thread(
|
||||
thread_id: str,
|
||||
message: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
):
|
||||
"""Add a message to a thread"""
|
||||
logger.debug(f"Adding message to thread: {thread_id}")
|
||||
client = await utils.db.client
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
try:
|
||||
message_result = await client.table('messages').insert({
|
||||
'thread_id': thread_id,
|
||||
|
@ -380,14 +381,14 @@ async def add_message_to_thread(
|
|||
async def create_message(
|
||||
thread_id: str,
|
||||
message_data: MessageCreateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Create a new message in a thread."""
|
||||
logger.debug(f"Creating message in thread: {thread_id}")
|
||||
client = await utils.db.client
|
||||
|
||||
try:
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
|
||||
message_payload = {
|
||||
"role": "user" if message_data.type == "user" else "assistant",
|
||||
|
@ -421,12 +422,12 @@ async def create_message(
|
|||
async def delete_message(
|
||||
thread_id: str,
|
||||
message_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Delete a message from a thread."""
|
||||
logger.debug(f"Deleting message from thread: {thread_id}")
|
||||
client = await utils.db.client
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
try:
|
||||
# Don't allow users to delete the "status" messages
|
||||
await client.table('messages').delete().eq('message_id', message_id).eq('is_llm_message', True).eq('thread_id', thread_id).execute()
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
|
||||
from .version_service import (
|
||||
get_version_service,
|
||||
|
@ -61,7 +61,7 @@ class VersionComparisonResponse(BaseModel):
|
|||
@router.get("/agents/{agent_id}/versions", response_model=List[VersionResponse])
|
||||
async def get_versions(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
version_service: VersionService = Depends(get_version_service)
|
||||
):
|
||||
try:
|
||||
|
@ -80,7 +80,7 @@ async def get_versions(
|
|||
async def create_version(
|
||||
agent_id: str,
|
||||
request: CreateVersionRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
version_service: VersionService = Depends(get_version_service)
|
||||
):
|
||||
try:
|
||||
|
@ -112,7 +112,7 @@ async def create_version(
|
|||
async def get_version(
|
||||
agent_id: str,
|
||||
version_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
version_service: VersionService = Depends(get_version_service)
|
||||
):
|
||||
try:
|
||||
|
@ -131,7 +131,7 @@ async def get_version(
|
|||
async def activate_version(
|
||||
agent_id: str,
|
||||
version_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
version_service: VersionService = Depends(get_version_service)
|
||||
):
|
||||
try:
|
||||
|
@ -153,7 +153,7 @@ async def compare_versions(
|
|||
agent_id: str,
|
||||
version1_id: str,
|
||||
version2_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
version_service: VersionService = Depends(get_version_service)
|
||||
):
|
||||
try:
|
||||
|
@ -179,7 +179,7 @@ async def compare_versions(
|
|||
async def rollback_to_version(
|
||||
agent_id: str,
|
||||
version_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
version_service: VersionService = Depends(get_version_service)
|
||||
):
|
||||
try:
|
||||
|
@ -202,7 +202,7 @@ async def update_version_details(
|
|||
agent_id: str,
|
||||
version_id: str,
|
||||
request: UpdateVersionDetailsRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
version_service: VersionService = Depends(get_version_service)
|
||||
):
|
||||
try:
|
||||
|
|
|
@ -80,7 +80,7 @@ class VersionService:
|
|||
async def _get_client(self):
|
||||
return await self.db.client
|
||||
|
||||
async def _verify_agent_access(self, agent_id: str, user_id: str) -> tuple[bool, bool]:
|
||||
async def _verify_and_authorize_agent_access(self, agent_id: str, user_id: str) -> tuple[bool, bool]:
|
||||
if user_id == "system":
|
||||
return True, True
|
||||
|
||||
|
@ -205,7 +205,7 @@ class VersionService:
|
|||
logger.debug(f"Creating version for agent {agent_id}")
|
||||
client = await self.db.client
|
||||
|
||||
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
|
||||
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
|
||||
if not is_owner:
|
||||
raise UnauthorizedError("Unauthorized to create version for this agent")
|
||||
|
||||
|
@ -298,7 +298,7 @@ class VersionService:
|
|||
return version
|
||||
|
||||
async def get_version(self, agent_id: str, version_id: str, user_id: str) -> AgentVersion:
|
||||
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
|
||||
is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id)
|
||||
if not is_owner and not is_public:
|
||||
raise UnauthorizedError("You don't have permission to view this version")
|
||||
|
||||
|
@ -314,7 +314,7 @@ class VersionService:
|
|||
return self._version_from_db_row(result.data[0])
|
||||
|
||||
async def get_active_version(self, agent_id: str, user_id: str = "system") -> Optional[AgentVersion]:
|
||||
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
|
||||
is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id)
|
||||
if not is_owner and not is_public:
|
||||
raise UnauthorizedError("You don't have permission to view this agent")
|
||||
|
||||
|
@ -341,7 +341,7 @@ class VersionService:
|
|||
return version
|
||||
|
||||
async def get_all_versions(self, agent_id: str, user_id: str) -> List[AgentVersion]:
|
||||
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
|
||||
is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id)
|
||||
if not is_owner and not is_public:
|
||||
raise UnauthorizedError("You don't have permission to view versions")
|
||||
|
||||
|
@ -355,7 +355,7 @@ class VersionService:
|
|||
return versions
|
||||
|
||||
async def activate_version(self, agent_id: str, version_id: str, user_id: str) -> None:
|
||||
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
|
||||
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
|
||||
if not is_owner:
|
||||
raise UnauthorizedError("You don't have permission to activate versions")
|
||||
|
||||
|
@ -458,7 +458,7 @@ class VersionService:
|
|||
) -> AgentVersion:
|
||||
version_to_restore = await self.get_version(agent_id, version_id, user_id)
|
||||
|
||||
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
|
||||
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
|
||||
if not is_owner:
|
||||
raise UnauthorizedError("You don't have permission to rollback versions")
|
||||
|
||||
|
@ -483,7 +483,7 @@ class VersionService:
|
|||
version_name: Optional[str] = None,
|
||||
change_description: Optional[str] = None
|
||||
) -> AgentVersion:
|
||||
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
|
||||
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
|
||||
if not is_owner:
|
||||
raise UnauthorizedError("You don't have permission to update this version")
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from agent.tools.expand_msg_tool import ExpandMessageTool
|
|||
from agent.prompts.prompt import get_system_prompt
|
||||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_account_id_from_thread
|
||||
|
||||
from services.billing import check_billing_status
|
||||
from agent.tools.sb_vision_tool import SandboxVisionTool
|
||||
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
|
||||
|
@ -448,9 +448,16 @@ class AgentRunner:
|
|||
)
|
||||
|
||||
self.client = await self.thread_manager.db.client
|
||||
self.account_id = await get_account_id_from_thread(self.client, self.config.thread_id)
|
||||
|
||||
response = await self.client.table('threads').select('account_id').eq('thread_id', self.config.thread_id).execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise ValueError(f"Thread {self.config.thread_id} not found")
|
||||
|
||||
self.account_id = response.data[0].get('account_id')
|
||||
|
||||
if not self.account_id:
|
||||
raise ValueError("Could not determine account ID for thread")
|
||||
raise ValueError(f"Thread {self.config.thread_id} has no associated account")
|
||||
|
||||
project = await self.client.table('projects').select('*').eq('project_id', self.config.project_id).execute()
|
||||
if not project.data or len(project.data) == 0:
|
||||
|
|
|
@ -7,7 +7,7 @@ from fastapi import HTTPException
|
|||
from utils.cache import Cache
|
||||
from utils.logger import logger
|
||||
from utils.config import config
|
||||
from utils.auth_utils import verify_thread_access
|
||||
from utils.auth_utils import verify_and_authorize_thread_access
|
||||
from services import redis
|
||||
from services.supabase import DBConnection
|
||||
from services.llm import make_llm_api_call
|
||||
|
@ -121,7 +121,7 @@ async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: st
|
|||
account_id = agent_run_data['threads']['account_id']
|
||||
if account_id == user_id:
|
||||
return agent_run_data
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
return agent_run_data
|
||||
|
||||
async def generate_and_update_project_name(project_id: str, prompt: str):
|
||||
|
|
|
@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
|
|||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
from uuid import uuid4
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, get_optional_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt, get_optional_current_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
from services.supabase import DBConnection
|
||||
from datetime import datetime
|
||||
|
@ -215,7 +215,7 @@ class ProfileResponse(BaseModel):
|
|||
|
||||
@router.get("/categories")
|
||||
async def list_categories(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
logger.debug("Fetching Composio categories")
|
||||
|
@ -240,7 +240,7 @@ async def list_toolkits(
|
|||
cursor: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
category: Optional[str] = Query(None),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
logger.debug(f"Fetching Composio toolkits with limit: {limit}, cursor: {cursor}, search: {search}, category: {category}")
|
||||
|
@ -270,7 +270,7 @@ async def list_toolkits(
|
|||
@router.get("/toolkits/{toolkit_slug}/details")
|
||||
async def get_toolkit_details(
|
||||
toolkit_slug: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
logger.debug(f"Fetching detailed toolkit info for: {toolkit_slug}")
|
||||
|
@ -296,7 +296,7 @@ async def get_toolkit_details(
|
|||
@router.post("/integrate", response_model=IntegrationStatusResponse)
|
||||
async def integrate_toolkit(
|
||||
request: IntegrateToolkitRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> IntegrationStatusResponse:
|
||||
try:
|
||||
integration_user_id = str(uuid4())
|
||||
|
@ -333,7 +333,7 @@ async def integrate_toolkit(
|
|||
@router.post("/profiles", response_model=ProfileResponse)
|
||||
async def create_profile(
|
||||
request: CreateProfileRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> ProfileResponse:
|
||||
try:
|
||||
integration_user_id = str(uuid4())
|
||||
|
@ -378,7 +378,7 @@ async def create_profile(
|
|||
@router.get("/profiles")
|
||||
async def get_profiles(
|
||||
toolkit_slug: Optional[str] = Query(None),
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
profile_service = ComposioProfileService(db)
|
||||
|
@ -403,7 +403,7 @@ async def get_profiles(
|
|||
@router.get("/profiles/{profile_id}/mcp-config")
|
||||
async def get_profile_mcp_config(
|
||||
profile_id: str,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
profile_service = ComposioProfileService(db)
|
||||
|
@ -423,7 +423,7 @@ async def get_profile_mcp_config(
|
|||
@router.get("/profiles/{profile_id}")
|
||||
async def get_profile_info(
|
||||
profile_id: str,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
profile_service = ComposioProfileService(db)
|
||||
|
@ -452,7 +452,7 @@ async def get_profile_info(
|
|||
@router.get("/integration/{connected_account_id}/status")
|
||||
async def get_integration_status(
|
||||
connected_account_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
service = get_integration_service()
|
||||
|
@ -466,7 +466,7 @@ async def get_integration_status(
|
|||
@router.post("/profiles/{profile_id}/discover-tools")
|
||||
async def discover_composio_tools(
|
||||
profile_id: str,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
profile_service = ComposioProfileService(db)
|
||||
|
@ -509,7 +509,7 @@ async def discover_composio_tools(
|
|||
@router.post("/discover-tools/{profile_id}")
|
||||
async def discover_tools_post(
|
||||
profile_id: str,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict[str, Any]:
|
||||
return await discover_composio_tools(profile_id, current_user_id)
|
||||
|
||||
|
@ -544,7 +544,7 @@ async def get_toolkit_icon(
|
|||
@router.post("/tools/list")
|
||||
async def list_toolkit_tools(
|
||||
request: ToolsListRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
logger.debug(f"User {current_user_id} requesting tools for toolkit: {request.toolkit_slug}")
|
||||
|
@ -572,7 +572,7 @@ async def list_toolkit_tools(
|
|||
|
||||
@router.get("/triggers/apps")
|
||||
async def list_apps_with_triggers(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
trigger_service = ComposioTriggerService()
|
||||
|
@ -588,7 +588,7 @@ async def list_apps_with_triggers(
|
|||
@router.get("/triggers/apps/{toolkit_slug}")
|
||||
async def list_triggers_for_app(
|
||||
toolkit_slug: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
trigger_service = ComposioTriggerService()
|
||||
|
@ -617,7 +617,7 @@ class CreateComposioTriggerRequest(BaseModel):
|
|||
|
||||
|
||||
@router.post("/triggers/create")
|
||||
async def create_composio_trigger(req: CreateComposioTriggerRequest, current_user_id: str = Depends(get_current_user_id_from_jwt)) -> Dict[str, Any]:
|
||||
async def create_composio_trigger(req: CreateComposioTriggerRequest, current_user_id: str = Depends(verify_and_get_user_id_from_jwt)) -> Dict[str, Any]:
|
||||
try:
|
||||
client_db = await db.client
|
||||
agent_check = await client_db.table('agents').select('agent_id').eq('agent_id', req.agent_id).eq('account_id', current_user_id).execute()
|
||||
|
|
|
@ -4,7 +4,7 @@ from pydantic import BaseModel, validator
|
|||
import urllib.parse
|
||||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from services.supabase import DBConnection
|
||||
|
||||
from .credential_service import (
|
||||
|
@ -116,7 +116,7 @@ def initialize(database: DBConnection):
|
|||
@router.post("/credentials", response_model=CredentialResponse)
|
||||
async def store_credential(
|
||||
request: StoreCredentialRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
credential_service = get_credential_service(db)
|
||||
|
@ -151,7 +151,7 @@ async def store_credential(
|
|||
|
||||
@router.get("/credentials", response_model=List[CredentialResponse])
|
||||
async def get_user_credentials(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
credential_service = get_credential_service(db)
|
||||
|
@ -178,7 +178,7 @@ async def get_user_credentials(
|
|||
@router.delete("/credentials/{mcp_qualified_name:path}")
|
||||
async def delete_credential(
|
||||
mcp_qualified_name: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
decoded_name = decode_mcp_qualified_name(mcp_qualified_name)
|
||||
|
@ -201,7 +201,7 @@ async def delete_credential(
|
|||
@router.post("/credential-profiles", response_model=CredentialProfileResponse)
|
||||
async def store_credential_profile(
|
||||
request: StoreCredentialProfileRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
profile_service = get_profile_service(db)
|
||||
|
@ -240,7 +240,7 @@ async def store_credential_profile(
|
|||
|
||||
@router.get("/credential-profiles", response_model=List[CredentialProfileResponse])
|
||||
async def get_user_credential_profiles(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
profile_service = get_profile_service(db)
|
||||
|
@ -269,7 +269,7 @@ async def get_user_credential_profiles(
|
|||
@router.get("/credential-profiles/{mcp_qualified_name:path}", response_model=List[CredentialProfileResponse])
|
||||
async def get_credential_profiles_for_mcp(
|
||||
mcp_qualified_name: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
decoded_name = decode_mcp_qualified_name(mcp_qualified_name)
|
||||
|
@ -300,7 +300,7 @@ async def get_credential_profiles_for_mcp(
|
|||
@router.get("/credential-profiles/profile/{profile_id}", response_model=CredentialProfileResponse)
|
||||
async def get_credential_profile(
|
||||
profile_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
profile_service = get_profile_service(db)
|
||||
|
@ -331,7 +331,7 @@ async def get_credential_profile(
|
|||
@router.put("/credential-profiles/{profile_id}/set-default")
|
||||
async def set_default_credential_profile(
|
||||
profile_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
profile_service = get_profile_service(db)
|
||||
|
@ -350,7 +350,7 @@ async def set_default_credential_profile(
|
|||
@router.delete("/credential-profiles/{profile_id}")
|
||||
async def delete_credential_profile(
|
||||
profile_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
profile_service = get_profile_service(db)
|
||||
|
@ -369,7 +369,7 @@ async def delete_credential_profile(
|
|||
@router.post("/credential-profiles/bulk-delete", response_model=BulkDeleteProfilesResponse)
|
||||
async def bulk_delete_credential_profiles(
|
||||
request: BulkDeleteProfilesRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
profile_service = get_profile_service(db)
|
||||
|
@ -399,7 +399,7 @@ async def bulk_delete_credential_profiles(
|
|||
|
||||
@router.get("/composio-profiles", response_model=ComposioCredentialsResponse)
|
||||
async def get_composio_profiles(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
profile_service = get_profile_service(db)
|
||||
|
@ -482,7 +482,7 @@ async def get_composio_profiles(
|
|||
@router.get("/composio-profiles/{profile_id}/mcp-url", response_model=ComposioMcpUrlResponse)
|
||||
async def get_composio_mcp_url(
|
||||
profile_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
from composio_integration.composio_profile_service import ComposioProfileService
|
||||
|
|
|
@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Depends
|
|||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
from services.supabase import DBConnection
|
||||
from .google_slides_service import GoogleSlidesService, OAuthTokenService
|
||||
|
@ -92,7 +92,7 @@ presentation_router = APIRouter(prefix="/presentation-tools", tags=["presentatio
|
|||
|
||||
@oauth_router.get("/auth-url", response_model=AuthURLResponse)
|
||||
async def get_google_auth_url(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
return_url: Optional[str] = Query(None, description="URL to redirect to after OAuth"),
|
||||
google_service: GoogleSlidesService = Depends(get_google_slides_service)
|
||||
):
|
||||
|
@ -172,7 +172,7 @@ async def google_oauth_callback(
|
|||
# UNUSED: Disconnect endpoint - frontend never calls this
|
||||
# @oauth_router.post("/disconnect")
|
||||
# async def disconnect_google(
|
||||
# user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
# user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
# google_service: GoogleSlidesService = Depends(get_google_slides_service)
|
||||
# ):
|
||||
# """
|
||||
|
@ -196,7 +196,7 @@ async def google_oauth_callback(
|
|||
@presentation_router.post("/convert-and-upload-to-slides", response_model=ConvertToSlidesResponse)
|
||||
async def convert_and_upload_to_google_slides(
|
||||
request: ConvertToSlidesRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
google_service: GoogleSlidesService = Depends(get_google_slides_service)
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, BackgroundTasks
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, verify_agent_access
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt, verify_and_get_agent_authorization, require_agent_access, AuthorizedAgentAccess
|
||||
from services.supabase import DBConnection
|
||||
from knowledge_base.file_processor import FileProcessor
|
||||
from utils.logger import logger
|
||||
|
@ -69,15 +69,16 @@ db = DBConnection()
|
|||
async def get_agent_knowledge_base(
|
||||
agent_id: str,
|
||||
include_inactive: bool = False,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
auth: AuthorizedAgentAccess = Depends(require_agent_access)
|
||||
):
|
||||
|
||||
"""Get all knowledge base entries for an agent"""
|
||||
try:
|
||||
client = await db.client
|
||||
user_id = auth.user_id # Already authenticated and authorized!
|
||||
agent_data = auth.agent_data # Agent data already fetched during authorization
|
||||
|
||||
# Verify agent access
|
||||
await verify_agent_access(client, agent_id, user_id)
|
||||
# No need for manual authorization - it's already done in the dependency!
|
||||
|
||||
result = await client.rpc('get_agent_knowledge_base', {
|
||||
'p_agent_id': agent_id,
|
||||
|
@ -122,7 +123,7 @@ async def get_agent_knowledge_base(
|
|||
async def create_agent_knowledge_base_entry(
|
||||
agent_id: str,
|
||||
entry_data: CreateKnowledgeBaseEntryRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
"""Create a new knowledge base entry for an agent"""
|
||||
|
@ -130,7 +131,7 @@ async def create_agent_knowledge_base_entry(
|
|||
client = await db.client
|
||||
|
||||
# Verify agent access and get agent data
|
||||
agent_data = await verify_agent_access(client, agent_id, user_id)
|
||||
agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
account_id = agent_data['account_id']
|
||||
|
||||
insert_data = {
|
||||
|
@ -172,7 +173,7 @@ async def upload_file_to_agent_kb(
|
|||
agent_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
"""Upload and process a file for agent knowledge base"""
|
||||
|
@ -180,7 +181,7 @@ async def upload_file_to_agent_kb(
|
|||
client = await db.client
|
||||
|
||||
# Verify agent access and get agent data
|
||||
agent_data = await verify_agent_access(client, agent_id, user_id)
|
||||
agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
account_id = agent_data['account_id']
|
||||
|
||||
file_content = await file.read()
|
||||
|
@ -226,7 +227,7 @@ async def upload_file_to_agent_kb(
|
|||
async def update_knowledge_base_entry(
|
||||
entry_id: str,
|
||||
entry_data: UpdateKnowledgeBaseEntryRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
"""Update an agent knowledge base entry"""
|
||||
|
@ -243,7 +244,7 @@ async def update_knowledge_base_entry(
|
|||
agent_id = entry['agent_id']
|
||||
|
||||
# Verify agent access
|
||||
await verify_agent_access(client, agent_id, user_id)
|
||||
await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
|
||||
update_data = {}
|
||||
if entry_data.name is not None:
|
||||
|
@ -294,7 +295,7 @@ async def update_knowledge_base_entry(
|
|||
@router.delete("/{entry_id}")
|
||||
async def delete_knowledge_base_entry(
|
||||
entry_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
"""Delete an agent knowledge base entry"""
|
||||
|
@ -311,7 +312,7 @@ async def delete_knowledge_base_entry(
|
|||
agent_id = entry['agent_id']
|
||||
|
||||
# Verify agent access
|
||||
await verify_agent_access(client, agent_id, user_id)
|
||||
await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
|
||||
result = await client.table('agent_knowledge_base_entries').delete().eq('entry_id', entry_id).execute()
|
||||
|
||||
|
@ -329,7 +330,7 @@ async def delete_knowledge_base_entry(
|
|||
@router.get("/{entry_id}", response_model=KnowledgeBaseEntryResponse)
|
||||
async def get_knowledge_base_entry(
|
||||
entry_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get a specific agent knowledge base entry"""
|
||||
try:
|
||||
|
@ -345,7 +346,7 @@ async def get_knowledge_base_entry(
|
|||
agent_id = entry['agent_id']
|
||||
|
||||
# Verify agent access
|
||||
await verify_agent_access(client, agent_id, user_id)
|
||||
await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
|
||||
logger.debug(f"Retrieved agent knowledge base entry {entry_id} for agent {agent_id}")
|
||||
|
||||
|
@ -376,7 +377,7 @@ async def get_knowledge_base_entry(
|
|||
async def get_agent_processing_jobs(
|
||||
agent_id: str,
|
||||
limit: int = 10,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
"""Get processing jobs for an agent"""
|
||||
|
@ -384,7 +385,7 @@ async def get_agent_processing_jobs(
|
|||
client = await db.client
|
||||
|
||||
# Verify agent access
|
||||
await verify_agent_access(client, agent_id, user_id)
|
||||
await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
|
||||
result = await client.rpc('get_agent_kb_processing_jobs', {
|
||||
'p_agent_id': agent_id,
|
||||
|
@ -468,7 +469,7 @@ async def process_file_background(
|
|||
async def get_agent_knowledge_base_context(
|
||||
agent_id: str,
|
||||
max_tokens: int = 4000,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
"""Get knowledge base context for agent prompts"""
|
||||
|
@ -476,7 +477,7 @@ async def get_agent_knowledge_base_context(
|
|||
client = await db.client
|
||||
|
||||
# Verify agent access
|
||||
await verify_agent_access(client, agent_id, user_id)
|
||||
await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
|
||||
result = await client.rpc('get_agent_knowledge_base_context', {
|
||||
'p_agent_id': agent_id,
|
||||
|
|
|
@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends
|
|||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
from .mcp_service import mcp_service, MCPException
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from uuid import UUID
|
|||
from datetime import datetime
|
||||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from .profile_service import ProfileService, Profile, ProfileServiceError, ProfileNotFoundError, ProfileAlreadyExistsError, InvalidConfigError, EncryptionError
|
||||
from .connection_service import ConnectionService
|
||||
from .app_service import get_app_service
|
||||
|
@ -162,7 +162,7 @@ def _handle_pipedream_exception(e: Exception) -> HTTPException:
|
|||
@router.post("/connection-token", response_model=ConnectionTokenResponse)
|
||||
async def create_connection_token(
|
||||
request: CreateConnectionTokenRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Creating Pipedream connection token for user: {user_id}, app: {request.app}")
|
||||
|
||||
|
@ -190,7 +190,7 @@ async def create_connection_token(
|
|||
|
||||
@router.get("/connections", response_model=ConnectionResponse)
|
||||
async def get_user_connections(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Getting connections for user: {user_id}")
|
||||
|
||||
|
@ -228,7 +228,7 @@ async def get_user_connections(
|
|||
@router.post("/mcp/discover", response_model=MCPDiscoveryResponse)
|
||||
async def discover_mcp_servers(
|
||||
request: MCPDiscoveryRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Discovering MCP servers for user: {user_id}, app: {request.app_slug}")
|
||||
|
||||
|
@ -277,7 +277,7 @@ async def discover_mcp_servers(
|
|||
@router.post("/mcp/discover-profile", response_model=MCPDiscoveryResponse)
|
||||
async def discover_mcp_servers_for_profile(
|
||||
request: MCPProfileDiscoveryRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Discovering MCP servers for profile: {request.external_user_id}")
|
||||
|
||||
|
@ -326,7 +326,7 @@ async def discover_mcp_servers_for_profile(
|
|||
@router.post("/mcp/connect", response_model=MCPConnectionResponse)
|
||||
async def create_mcp_connection(
|
||||
request: MCPConnectionRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Creating MCP connection for user: {user_id}, app: {request.app_slug}")
|
||||
|
||||
|
@ -535,7 +535,7 @@ async def get_app_tools(app_slug: str):
|
|||
@router.post("/profiles", response_model=ProfileResponse)
|
||||
async def create_credential_profile(
|
||||
request: ProfileRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Creating credential profile for user: {user_id}, app: {request.app_slug}")
|
||||
|
||||
|
@ -563,7 +563,7 @@ async def create_credential_profile(
|
|||
async def get_credential_profiles(
|
||||
app_slug: Optional[str] = Query(None),
|
||||
is_active: Optional[bool] = Query(None),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Getting credential profiles for user: {user_id}, app: {app_slug}")
|
||||
|
||||
|
@ -581,7 +581,7 @@ async def get_credential_profiles(
|
|||
@router.get("/profiles/{profile_id}", response_model=ProfileResponse)
|
||||
async def get_credential_profile(
|
||||
profile_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Getting credential profile: {profile_id} for user: {user_id}")
|
||||
|
||||
|
@ -603,7 +603,7 @@ async def get_credential_profile(
|
|||
async def update_credential_profile(
|
||||
profile_id: str,
|
||||
request: UpdateProfileRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Updating credential profile: {profile_id} for user: {user_id}")
|
||||
|
||||
|
@ -628,7 +628,7 @@ async def update_credential_profile(
|
|||
@router.delete("/profiles/{profile_id}")
|
||||
async def delete_credential_profile(
|
||||
profile_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Deleting credential profile: {profile_id} for user: {user_id}")
|
||||
|
||||
|
@ -649,7 +649,7 @@ async def delete_credential_profile(
|
|||
async def connect_credential_profile(
|
||||
profile_id: str,
|
||||
app: Optional[str] = Query(None),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Connecting credential profile: {profile_id} for user: {user_id}")
|
||||
|
||||
|
@ -686,7 +686,7 @@ async def connect_credential_profile(
|
|||
@router.get("/profiles/{profile_id}/connections")
|
||||
async def get_profile_connections(
|
||||
profile_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
logger.debug(f"Getting connections for profile: {profile_id}, user: {user_id}")
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from services.api_keys import (
|
|||
APIKeyCreateResponse,
|
||||
)
|
||||
from services.supabase import DBConnection
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -61,7 +61,7 @@ async def get_account_id_from_user_id(user_id: str) -> UUID:
|
|||
@router.post("/api-keys", response_model=APIKeyCreateResponse)
|
||||
async def create_api_key(
|
||||
request: APIKeyCreateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
api_key_service: APIKeyService = Depends(get_api_key_service),
|
||||
):
|
||||
"""
|
||||
|
@ -105,7 +105,7 @@ async def create_api_key(
|
|||
|
||||
@router.get("/api-keys", response_model=List[APIKeyResponse])
|
||||
async def list_api_keys(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
api_key_service: APIKeyService = Depends(get_api_key_service),
|
||||
):
|
||||
"""
|
||||
|
@ -141,7 +141,7 @@ async def list_api_keys(
|
|||
@router.patch("/api-keys/{key_id}/revoke")
|
||||
async def revoke_api_key(
|
||||
key_id: UUID,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
api_key_service: APIKeyService = Depends(get_api_key_service),
|
||||
):
|
||||
"""
|
||||
|
@ -185,7 +185,7 @@ async def revoke_api_key(
|
|||
@router.delete("/api-keys/{key_id}")
|
||||
async def delete_api_key(
|
||||
key_id: UUID,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt),
|
||||
api_key_service: APIKeyService = Depends(get_api_key_service),
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -15,7 +15,7 @@ from utils.cache import Cache
|
|||
from utils.logger import logger
|
||||
from utils.config import config, EnvMode
|
||||
from services.supabase import DBConnection
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from pydantic import BaseModel
|
||||
from models import model_manager
|
||||
from litellm.cost_calculator import cost_per_token
|
||||
|
@ -1234,7 +1234,7 @@ async def handle_usage_with_credits(
|
|||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Create a Stripe Checkout session or modify an existing subscription."""
|
||||
try:
|
||||
|
@ -1555,7 +1555,7 @@ async def create_checkout_session(
|
|||
@router.post("/create-portal-session")
|
||||
async def create_portal_session(
|
||||
request: CreatePortalSessionRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Create a Stripe Customer Portal session for subscription management."""
|
||||
try:
|
||||
|
@ -1655,7 +1655,7 @@ async def create_portal_session(
|
|||
|
||||
@router.get("/subscription")
|
||||
async def get_subscription(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get the current subscription status for the current user, including scheduled changes and credit balance."""
|
||||
try:
|
||||
|
@ -1891,7 +1891,7 @@ async def get_subscription(
|
|||
|
||||
@router.get("/check-status")
|
||||
async def check_status(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Check if the user can run agents based on their subscription, usage, and credit balance."""
|
||||
try:
|
||||
|
@ -2076,7 +2076,7 @@ async def stripe_webhook(request: Request):
|
|||
|
||||
@router.get("/available-models")
|
||||
async def get_available_models(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get the list of models available to the user based on their subscription tier."""
|
||||
try:
|
||||
|
@ -2200,7 +2200,7 @@ async def get_available_models(
|
|||
async def get_usage_logs_endpoint(
|
||||
page: int = 0,
|
||||
items_per_page: int = 1000,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get detailed usage logs for a user with pagination."""
|
||||
logger.debug(f"[USAGE_LOGS_ENDPOINT] Starting get_usage_logs_endpoint for user_id={current_user_id}, page={page}, items_per_page={items_per_page}")
|
||||
|
@ -2254,7 +2254,7 @@ async def get_usage_logs_endpoint(
|
|||
@router.get("/subscription-commitment/{subscription_id}")
|
||||
async def get_subscription_commitment(
|
||||
subscription_id: str,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get commitment status for a subscription."""
|
||||
try:
|
||||
|
@ -2278,7 +2278,7 @@ async def get_subscription_commitment(
|
|||
|
||||
@router.get("/subscription-details")
|
||||
async def get_subscription_details(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get detailed subscription information including commitment status."""
|
||||
try:
|
||||
|
@ -2314,7 +2314,7 @@ async def get_subscription_details(
|
|||
|
||||
@router.post("/cancel-subscription")
|
||||
async def cancel_subscription(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Cancel subscription with yearly commitment handling."""
|
||||
try:
|
||||
|
@ -2412,7 +2412,7 @@ async def cancel_subscription(
|
|||
|
||||
@router.post("/reactivate-subscription")
|
||||
async def reactivate_subscription(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Reactivate a subscription that was marked for cancellation."""
|
||||
try:
|
||||
|
@ -2492,7 +2492,7 @@ async def reactivate_subscription(
|
|||
@router.post("/purchase-credits")
|
||||
async def purchase_credits(
|
||||
request: PurchaseCreditsRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""
|
||||
Create a Stripe checkout session for purchasing credits.
|
||||
|
@ -2609,7 +2609,7 @@ async def purchase_credits(
|
|||
|
||||
@router.get("/credit-balance")
|
||||
async def get_credit_balance_endpoint(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get the current credit balance for the user."""
|
||||
try:
|
||||
|
@ -2628,7 +2628,7 @@ async def get_credit_balance_endpoint(
|
|||
async def get_credit_history(
|
||||
page: int = 0,
|
||||
items_per_page: int = 50,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get credit purchase and usage history for the user."""
|
||||
try:
|
||||
|
@ -2690,7 +2690,7 @@ async def get_credit_history(
|
|||
|
||||
@router.get("/can-purchase-credits")
|
||||
async def can_purchase_credits(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Check if the current user can purchase credits (must be on highest tier)."""
|
||||
try:
|
||||
|
|
|
@ -5,7 +5,7 @@ from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
|
||||
router = APIRouter(tags=["transcription"])
|
||||
|
||||
|
@ -15,7 +15,7 @@ class TranscriptionResponse(BaseModel):
|
|||
@router.post("/transcription", response_model=TranscriptionResponse)
|
||||
async def transcribe_audio(
|
||||
audio_file: UploadFile = File(...),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Transcribe audio file to text using OpenAI Whisper."""
|
||||
try:
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
-- Remove the problematic GRANT that bypasses RLS
|
||||
REVOKE SELECT ON TABLE projects FROM anon;
|
||||
REVOKE SELECT ON TABLE threads FROM anon;
|
||||
REVOKE SELECT ON TABLE messages FROM anon;
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from services.supabase import DBConnection
|
||||
from utils.pagination import PaginationParams
|
||||
|
||||
|
@ -133,7 +133,7 @@ async def validate_agent_ownership(agent_id: str, user_id: str) -> Dict[str, Any
|
|||
@router.post("", response_model=Dict[str, str])
|
||||
async def create_template_from_agent(
|
||||
request: CreateTemplateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
await validate_agent_ownership(request.agent_id, user_id)
|
||||
|
@ -172,7 +172,7 @@ async def create_template_from_agent(
|
|||
async def publish_template(
|
||||
template_id: str,
|
||||
request: PublishTemplateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
template = await validate_template_ownership_and_get(template_id, user_id)
|
||||
|
@ -200,7 +200,7 @@ async def publish_template(
|
|||
@router.post("/{template_id}/unpublish")
|
||||
async def unpublish_template(
|
||||
template_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
template = await validate_template_ownership_and_get(template_id, user_id)
|
||||
|
@ -228,7 +228,7 @@ async def unpublish_template(
|
|||
@router.delete("/{template_id}")
|
||||
async def delete_template(
|
||||
template_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
template = await validate_template_ownership_and_get(template_id, user_id)
|
||||
|
@ -256,7 +256,7 @@ async def delete_template(
|
|||
@router.post("/install", response_model=InstallationResponse)
|
||||
async def install_template(
|
||||
request: InstallTemplateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
await validate_template_access_and_get(request.template_id, user_id)
|
||||
|
@ -345,8 +345,8 @@ async def get_marketplace_templates(
|
|||
creator_id_filter = None
|
||||
if mine:
|
||||
try:
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
user_id = await get_current_user_id_from_jwt(request)
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
user_id = await verify_and_get_user_id_from_jwt(request)
|
||||
creator_id_filter = user_id
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=401, detail="Authentication required for 'mine' filter")
|
||||
|
@ -406,7 +406,7 @@ async def get_my_templates(
|
|||
search: Optional[str] = Query(None, description="Search term for name and description"),
|
||||
sort_by: Optional[str] = Query("created_at", description="Sort field: created_at, name, download_count"),
|
||||
sort_order: Optional[str] = Query("desc", description="Sort order: asc, desc"),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
from templates.services.template_service import TemplateService, MarketplaceFilters
|
||||
|
@ -497,7 +497,7 @@ async def get_public_template(template_id: str):
|
|||
@router.get("/{template_id}", response_model=TemplateResponse)
|
||||
async def get_template(
|
||||
template_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
template = await validate_template_access_and_get(template_id, user_id)
|
||||
|
|
|
@ -9,7 +9,7 @@ import json
|
|||
import hmac
|
||||
|
||||
from services.supabase import DBConnection
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from utils.logger import logger
|
||||
from utils.config import config
|
||||
from services.billing import check_billing_status, can_use_model
|
||||
|
@ -131,7 +131,7 @@ def initialize(database: DBConnection):
|
|||
db = database
|
||||
|
||||
|
||||
async def verify_agent_access(agent_id: str, user_id: str):
|
||||
async def verify_and_authorize_trigger_agent_access(agent_id: str, user_id: str):
|
||||
client = await db.client
|
||||
result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', user_id).execute()
|
||||
|
||||
|
@ -228,10 +228,10 @@ async def get_providers():
|
|||
@router.get("/agents/{agent_id}/triggers", response_model=List[TriggerResponse])
|
||||
async def get_agent_triggers(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
try:
|
||||
trigger_service = get_trigger_service(db)
|
||||
|
@ -266,7 +266,7 @@ async def get_agent_triggers(
|
|||
|
||||
@router.get("/all", response_model=List[Dict[str, Any]])
|
||||
async def get_all_user_triggers(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
try:
|
||||
client = await db.client
|
||||
|
@ -347,11 +347,11 @@ async def get_all_user_triggers(
|
|||
async def get_agent_upcoming_runs(
|
||||
agent_id: str,
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get upcoming scheduled runs for agent triggers"""
|
||||
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
try:
|
||||
trigger_service = get_trigger_service(db)
|
||||
|
@ -419,11 +419,11 @@ async def get_agent_upcoming_runs(
|
|||
async def create_agent_trigger(
|
||||
agent_id: str,
|
||||
request: TriggerCreateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Create a new trigger for an agent"""
|
||||
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
try:
|
||||
trigger_service = get_trigger_service(db)
|
||||
|
@ -466,7 +466,7 @@ async def create_agent_trigger(
|
|||
@router.get("/{trigger_id}", response_model=TriggerResponse)
|
||||
async def get_trigger(
|
||||
trigger_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get a trigger by ID"""
|
||||
|
||||
|
@ -477,7 +477,7 @@ async def get_trigger(
|
|||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
await verify_agent_access(trigger.agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id)
|
||||
|
||||
base_url = os.getenv("WEBHOOK_BASE_URL", "http://localhost:8000")
|
||||
webhook_url = f"{base_url}/api/triggers/{trigger_id}/webhook"
|
||||
|
@ -505,7 +505,7 @@ async def get_trigger(
|
|||
async def update_trigger(
|
||||
trigger_id: str,
|
||||
request: TriggerUpdateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Update a trigger"""
|
||||
|
||||
|
@ -516,7 +516,7 @@ async def update_trigger(
|
|||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
await verify_agent_access(trigger.agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id)
|
||||
|
||||
updated_trigger = await trigger_service.update_trigger(
|
||||
trigger_id=trigger_id,
|
||||
|
@ -556,7 +556,7 @@ async def update_trigger(
|
|||
@router.delete("/{trigger_id}")
|
||||
async def delete_trigger(
|
||||
trigger_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Delete a trigger"""
|
||||
|
||||
|
@ -566,7 +566,7 @@ async def delete_trigger(
|
|||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
await verify_agent_access(trigger.agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id)
|
||||
|
||||
# Store agent_id before deletion
|
||||
agent_id = trigger.agent_id
|
||||
|
@ -709,10 +709,10 @@ def convert_steps_to_json(steps: List[WorkflowStepRequest]) -> List[Dict[str, An
|
|||
@workflows_router.get("/agents/{agent_id}/workflows")
|
||||
async def get_agent_workflows(
|
||||
agent_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Get workflows for an agent"""
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
client = await db.client
|
||||
result = await client.table('agent_workflows').select('*').eq('agent_id', agent_id).order('created_at', desc=True).execute()
|
||||
|
@ -724,10 +724,10 @@ async def get_agent_workflows(
|
|||
async def create_agent_workflow(
|
||||
agent_id: str,
|
||||
workflow_data: WorkflowCreateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Create a new workflow for an agent"""
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
try:
|
||||
client = await db.client
|
||||
|
@ -758,10 +758,10 @@ async def update_agent_workflow(
|
|||
agent_id: str,
|
||||
workflow_id: str,
|
||||
workflow_data: WorkflowUpdateRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
"""Update a workflow"""
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
client = await db.client
|
||||
|
||||
|
@ -800,9 +800,9 @@ async def update_agent_workflow(
|
|||
async def delete_agent_workflow(
|
||||
agent_id: str,
|
||||
workflow_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
client = await db.client
|
||||
|
||||
|
@ -822,10 +822,10 @@ async def execute_agent_workflow(
|
|||
agent_id: str,
|
||||
workflow_id: str,
|
||||
execution_data: WorkflowExecuteRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
):
|
||||
print("DEBUG: Executing workflow", workflow_id, "for agent", agent_id)
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
|
||||
|
||||
client = await db.client
|
||||
|
||||
|
|
|
@ -6,9 +6,93 @@ from jwt.exceptions import PyJWTError
|
|||
from utils.logger import structlog
|
||||
from utils.config import config
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
from services.supabase import DBConnection
|
||||
from services import redis
|
||||
|
||||
async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)):
|
||||
if not config.KORTIX_ADMIN_API_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Admin API key not configured on server"
|
||||
)
|
||||
|
||||
if not x_admin_api_key:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Admin API key required. Include X-Admin-Api-Key header."
|
||||
)
|
||||
|
||||
if x_admin_api_key != config.KORTIX_ADMIN_API_KEY:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Invalid admin API key"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _verify_jwt_signature(token: str, secret: str) -> dict:
|
||||
"""
|
||||
Securely verify JWT signature using the Supabase JWT secret.
|
||||
|
||||
Args:
|
||||
token: The JWT token to verify
|
||||
secret: The Supabase JWT secret for signature verification
|
||||
|
||||
Returns:
|
||||
dict: The decoded JWT payload if signature is valid
|
||||
|
||||
Raises:
|
||||
PyJWTError: If signature verification fails or token is invalid
|
||||
"""
|
||||
try:
|
||||
# Decode and verify the JWT signature using HS256 algorithm (Supabase default)
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
secret,
|
||||
algorithms=["HS256"],
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Supabase JWTs may not always have audience
|
||||
"verify_iss": False # Supabase JWTs may not always have issuer
|
||||
}
|
||||
)
|
||||
return payload
|
||||
except PyJWTError as e:
|
||||
structlog.get_logger().warning(f"JWT signature verification failed: {e}")
|
||||
raise
|
||||
|
||||
def _decode_jwt_safely(token: str) -> dict:
|
||||
"""
|
||||
Safely decode and verify a JWT token with proper signature verification.
|
||||
Falls back to no verification only if JWT secret is not configured.
|
||||
|
||||
Args:
|
||||
token: The JWT token to decode
|
||||
|
||||
Returns:
|
||||
dict: The decoded JWT payload
|
||||
|
||||
Raises:
|
||||
PyJWTError: If token is invalid or signature verification fails
|
||||
"""
|
||||
jwt_secret = config.SUPABASE_JWT_SECRET
|
||||
|
||||
if jwt_secret:
|
||||
# Production mode: Verify signature
|
||||
structlog.get_logger().debug("Verifying JWT signature")
|
||||
return _verify_jwt_signature(token, jwt_secret)
|
||||
else:
|
||||
# Development mode: Log warning and decode without verification
|
||||
structlog.get_logger().warning(
|
||||
"JWT_SECRET not configured - using insecure JWT decoding. "
|
||||
"This should only be used in development!"
|
||||
)
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get user_id from account_id with Redis caching for performance
|
||||
|
@ -58,7 +142,7 @@ async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
# This function extracts the user ID from Supabase JWT
|
||||
async def get_current_user_id_from_jwt(request: Request) -> str:
|
||||
async def verify_and_get_user_id_from_jwt(request: Request) -> str:
|
||||
"""
|
||||
Extract and verify the user ID from the JWT in the Authorization header or API key.
|
||||
|
||||
|
@ -149,7 +233,7 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
|
|||
token = auth_header.split(' ')[1]
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
payload = _decode_jwt_safely(token)
|
||||
user_id = payload.get('sub')
|
||||
|
||||
if not user_id:
|
||||
|
@ -173,51 +257,6 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
|
|||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
async def get_account_id_from_thread(client, thread_id: str) -> str:
|
||||
"""
|
||||
Extract and verify the account ID from the thread.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
thread_id: The ID of the thread
|
||||
|
||||
Returns:
|
||||
str: The account ID associated with the thread
|
||||
|
||||
Raises:
|
||||
HTTPException: If the thread is not found or if there's an error
|
||||
"""
|
||||
try:
|
||||
response = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Thread not found"
|
||||
)
|
||||
|
||||
account_id = response.data[0].get('account_id')
|
||||
|
||||
if not account_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Thread has no associated account"
|
||||
)
|
||||
|
||||
return account_id
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "cannot schedule new futures after shutdown" in error_msg or "connection is closed" in error_msg:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Server is shutting down"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error retrieving thread information: {str(e)}"
|
||||
)
|
||||
|
||||
async def get_user_id_from_stream_auth(
|
||||
request: Request,
|
||||
|
@ -246,7 +285,7 @@ async def get_user_id_from_stream_auth(
|
|||
try:
|
||||
# First, try the standard authentication (handles both API keys and Authorization header)
|
||||
try:
|
||||
return await get_current_user_id_from_jwt(request)
|
||||
return await verify_and_get_user_id_from_jwt(request)
|
||||
except HTTPException:
|
||||
# If standard auth fails, try query parameter JWT for EventSource compatibility
|
||||
pass
|
||||
|
@ -254,8 +293,8 @@ async def get_user_id_from_stream_auth(
|
|||
# Try to get user_id from token in query param (for EventSource which can't set headers)
|
||||
if token:
|
||||
try:
|
||||
# For Supabase JWT, we just need to decode and extract the user ID
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
# For Supabase JWT, verify signature and extract the user ID
|
||||
payload = _decode_jwt_safely(token)
|
||||
user_id = payload.get('sub')
|
||||
if user_id:
|
||||
sentry.sentry.set_user({ "id": user_id })
|
||||
|
@ -289,7 +328,76 @@ async def get_user_id_from_stream_auth(
|
|||
detail=f"Error during authentication: {str(e)}"
|
||||
)
|
||||
|
||||
async def verify_thread_access(client, thread_id: str, user_id: str):
|
||||
async def get_optional_user_id(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract the user ID from the JWT in the Authorization header if present,
|
||||
but don't require authentication. Returns None if no valid token is found.
|
||||
|
||||
This function is used for endpoints that support both authenticated and
|
||||
unauthenticated access (like public projects).
|
||||
|
||||
Args:
|
||||
request: The FastAPI request object
|
||||
|
||||
Returns:
|
||||
Optional[str]: The user ID extracted from the JWT, or None if no valid token
|
||||
"""
|
||||
auth_header = request.headers.get('Authorization')
|
||||
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
return None
|
||||
|
||||
token = auth_header.split(' ')[1]
|
||||
|
||||
try:
|
||||
# For Supabase JWT, verify signature and extract the user ID
|
||||
payload = _decode_jwt_safely(token)
|
||||
|
||||
# Supabase stores the user ID in the 'sub' claim
|
||||
user_id = payload.get('sub')
|
||||
if user_id:
|
||||
sentry.sentry.set_user({ "id": user_id })
|
||||
structlog.contextvars.bind_contextvars(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return user_id
|
||||
except PyJWTError:
|
||||
return None
|
||||
|
||||
# Alias for consistency with other auth functions
|
||||
get_optional_current_user_id_from_jwt = get_optional_user_id
|
||||
|
||||
async def verify_and_get_agent_authorization(client, agent_id: str, user_id: str) -> dict:
|
||||
"""
|
||||
Verify that a user has access to a specific agent based on ownership.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
agent_id: The agent ID to check access for
|
||||
user_id: The user ID to check permissions for
|
||||
|
||||
Returns:
|
||||
dict: Agent data if access is granted
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user doesn't have access to the agent or agent doesn't exist
|
||||
"""
|
||||
try:
|
||||
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute()
|
||||
|
||||
if not agent_result.data:
|
||||
raise HTTPException(status_code=404, detail="Agent not found or access denied")
|
||||
|
||||
return agent_result.data[0]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
structlog.error(f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to verify agent access")
|
||||
|
||||
async def verify_and_authorize_thread_access(client, thread_id: str, user_id: str):
|
||||
"""
|
||||
Verify that a user has access to a specific thread based on account membership.
|
||||
|
||||
|
@ -347,92 +455,137 @@ async def verify_thread_access(client, thread_id: str, user_id: str):
|
|||
detail=f"Error verifying thread access: {str(e)}"
|
||||
)
|
||||
|
||||
async def get_optional_user_id(request: Request) -> Optional[str]:
|
||||
# ============================================================================
|
||||
# FastAPI Dependency Functions for Combined Authentication + Authorization
|
||||
# ============================================================================
|
||||
|
||||
async def get_authorized_user_for_thread(
|
||||
thread_id: str,
|
||||
request: Request
|
||||
) -> str:
|
||||
"""
|
||||
Extract the user ID from the JWT in the Authorization header if present,
|
||||
but don't require authentication. Returns None if no valid token is found.
|
||||
|
||||
This function is used for endpoints that support both authenticated and
|
||||
unauthenticated access (like public projects).
|
||||
FastAPI dependency that verifies JWT and authorizes thread access.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID to authorize access for
|
||||
request: The FastAPI request object
|
||||
|
||||
Returns:
|
||||
Optional[str]: The user ID extracted from the JWT, or None if no valid token
|
||||
"""
|
||||
auth_header = request.headers.get('Authorization')
|
||||
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
return None
|
||||
|
||||
token = auth_header.split(' ')[1]
|
||||
|
||||
try:
|
||||
# For Supabase JWT, we just need to decode and extract the user ID
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
# Supabase stores the user ID in the 'sub' claim
|
||||
user_id = payload.get('sub')
|
||||
if user_id:
|
||||
sentry.sentry.set_user({ "id": user_id })
|
||||
structlog.contextvars.bind_contextvars(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return user_id
|
||||
except PyJWTError:
|
||||
return None
|
||||
|
||||
# Alias for consistency with other auth functions
|
||||
get_optional_current_user_id_from_jwt = get_optional_user_id
|
||||
|
||||
async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)):
|
||||
if not config.KORTIX_ADMIN_API_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Admin API key not configured on server"
|
||||
)
|
||||
|
||||
if not x_admin_api_key:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Admin API key required. Include X-Admin-Api-Key header."
|
||||
)
|
||||
|
||||
if x_admin_api_key != config.KORTIX_ADMIN_API_KEY:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Invalid admin API key"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def verify_agent_access(client, agent_id: str, user_id: str) -> dict:
|
||||
"""
|
||||
Verify that a user has access to a specific agent based on ownership.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
agent_id: The agent ID to check access for
|
||||
user_id: The user ID to check permissions for
|
||||
|
||||
Returns:
|
||||
dict: Agent data if access is granted
|
||||
str: The authenticated and authorized user ID
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user doesn't have access to the agent or agent doesn't exist
|
||||
HTTPException: If authentication fails or user lacks thread access
|
||||
"""
|
||||
try:
|
||||
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute()
|
||||
from services.supabase import DBConnection
|
||||
|
||||
# First, authenticate the user
|
||||
user_id = await verify_and_get_user_id_from_jwt(request)
|
||||
|
||||
# Then, authorize thread access
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
await verify_and_authorize_thread_access(client, thread_id, user_id)
|
||||
|
||||
return user_id
|
||||
|
||||
async def get_authorized_user_for_agent(
|
||||
agent_id: str,
|
||||
request: Request
|
||||
) -> tuple[str, dict]:
|
||||
"""
|
||||
FastAPI dependency that verifies JWT and authorizes agent access.
|
||||
|
||||
Args:
|
||||
agent_id: The agent ID to authorize access for
|
||||
request: The FastAPI request object
|
||||
|
||||
if not agent_result.data:
|
||||
raise HTTPException(status_code=404, detail="Agent not found or access denied")
|
||||
Returns:
|
||||
tuple[str, dict]: The authenticated user ID and agent data
|
||||
|
||||
return agent_result.data[0]
|
||||
Raises:
|
||||
HTTPException: If authentication fails or user lacks agent access
|
||||
"""
|
||||
from services.supabase import DBConnection
|
||||
|
||||
# First, authenticate the user
|
||||
user_id = await verify_and_get_user_id_from_jwt(request)
|
||||
|
||||
# Then, authorize agent access and get agent data
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)
|
||||
|
||||
return user_id, agent_data
|
||||
|
||||
class AuthorizedThreadAccess:
|
||||
"""
|
||||
FastAPI dependency that combines authentication and thread authorization.
|
||||
|
||||
Usage:
|
||||
@router.get("/threads/{thread_id}/messages")
|
||||
async def get_messages(
|
||||
thread_id: str,
|
||||
auth: AuthorizedThreadAccess = Depends()
|
||||
):
|
||||
user_id = auth.user_id # Authenticated and authorized user
|
||||
"""
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
|
||||
class AuthorizedAgentAccess:
|
||||
"""
|
||||
FastAPI dependency that combines authentication and agent authorization.
|
||||
|
||||
Usage:
|
||||
@router.get("/agents/{agent_id}/config")
|
||||
async def get_agent_config(
|
||||
agent_id: str,
|
||||
auth: AuthorizedAgentAccess = Depends()
|
||||
):
|
||||
user_id = auth.user_id # Authenticated and authorized user
|
||||
agent_data = auth.agent_data # Agent data from authorization check
|
||||
"""
|
||||
def __init__(self, user_id: str, agent_data: dict):
|
||||
self.user_id = user_id
|
||||
self.agent_data = agent_data
|
||||
|
||||
async def require_thread_access(
|
||||
thread_id: str,
|
||||
request: Request
|
||||
) -> AuthorizedThreadAccess:
|
||||
"""
|
||||
FastAPI dependency that verifies JWT and authorizes thread access.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID from the path parameter
|
||||
request: The FastAPI request object
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
structlog.error(f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to verify agent access")
|
||||
Returns:
|
||||
AuthorizedThreadAccess: Object containing authenticated user_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails or user lacks thread access
|
||||
"""
|
||||
user_id = await get_authorized_user_for_thread(thread_id, request)
|
||||
return AuthorizedThreadAccess(user_id)
|
||||
|
||||
async def require_agent_access(
|
||||
agent_id: str,
|
||||
request: Request
|
||||
) -> AuthorizedAgentAccess:
|
||||
"""
|
||||
FastAPI dependency that verifies JWT and authorizes agent access.
|
||||
|
||||
Args:
|
||||
agent_id: The agent ID from the path parameter
|
||||
request: The FastAPI request object
|
||||
|
||||
Returns:
|
||||
AuthorizedAgentAccess: Object containing user_id and agent_data
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails or user lacks agent access
|
||||
"""
|
||||
user_id, agent_data = await get_authorized_user_for_agent(agent_id, request)
|
||||
return AuthorizedAgentAccess(user_id, agent_data)
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ Usage:
|
|||
|
||||
import os
|
||||
from enum import Enum
|
||||
from re import S
|
||||
from typing import Dict, Any, Optional, get_type_hints, Union
|
||||
from dotenv import load_dotenv
|
||||
import logging
|
||||
|
@ -277,6 +278,7 @@ class Configuration:
|
|||
SUPABASE_URL: str
|
||||
SUPABASE_ANON_KEY: str
|
||||
SUPABASE_SERVICE_ROLE_KEY: str
|
||||
SUPABASE_JWT_SECRET: str
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST: str
|
||||
|
|
26
setup.py
26
setup.py
|
@ -119,6 +119,7 @@ def load_existing_env_vars():
|
|||
"SUPABASE_SERVICE_ROLE_KEY": backend_env.get(
|
||||
"SUPABASE_SERVICE_ROLE_KEY", ""
|
||||
),
|
||||
"SUPABASE_JWT_SECRET": backend_env.get("SUPABASE_JWT_SECRET", ""),
|
||||
},
|
||||
"daytona": {
|
||||
"DAYTONA_API_KEY": backend_env.get("DAYTONA_API_KEY", ""),
|
||||
|
@ -285,8 +286,17 @@ class SetupWizard:
|
|||
config_items = []
|
||||
|
||||
# Check Supabase
|
||||
if self.env_vars["supabase"]["SUPABASE_URL"]:
|
||||
config_items.append(f"{Colors.GREEN}✓{Colors.ENDC} Supabase")
|
||||
supabase_complete = (
|
||||
self.env_vars["supabase"]["SUPABASE_URL"] and
|
||||
self.env_vars["supabase"]["SUPABASE_ANON_KEY"] and
|
||||
self.env_vars["supabase"]["SUPABASE_SERVICE_ROLE_KEY"]
|
||||
)
|
||||
supabase_secure = self.env_vars["supabase"]["SUPABASE_JWT_SECRET"]
|
||||
|
||||
if supabase_complete and supabase_secure:
|
||||
config_items.append(f"{Colors.GREEN}✓{Colors.ENDC} Supabase (secure)")
|
||||
elif supabase_complete:
|
||||
config_items.append(f"{Colors.YELLOW}⚠{Colors.ENDC} Supabase (missing JWT secret)")
|
||||
else:
|
||||
config_items.append(f"{Colors.YELLOW}○{Colors.ENDC} Supabase")
|
||||
|
||||
|
@ -600,8 +610,12 @@ class SetupWizard:
|
|||
"You'll need a Supabase project. Visit https://supabase.com/dashboard/projects to create one."
|
||||
)
|
||||
print_info(
|
||||
"In your project settings, go to 'API' to find the required information."
|
||||
"In your project settings, go to 'API' to find the required information:"
|
||||
)
|
||||
print_info(" - Project URL (at the top)")
|
||||
print_info(" - anon public key (under 'Project API keys')")
|
||||
print_info(" - service_role secret key (under 'Project API keys')")
|
||||
print_info(" - JWT Secret (under 'JWT Settings' - critical for security!)")
|
||||
input("Press Enter to continue once you have your project details...")
|
||||
|
||||
self.env_vars["supabase"]["SUPABASE_URL"] = self._get_input(
|
||||
|
@ -622,6 +636,12 @@ class SetupWizard:
|
|||
"This does not look like a valid key. It should be at least 10 characters.",
|
||||
default_value=self.env_vars["supabase"]["SUPABASE_SERVICE_ROLE_KEY"],
|
||||
)
|
||||
self.env_vars["supabase"]["SUPABASE_JWT_SECRET"] = self._get_input(
|
||||
"Enter your Supabase JWT secret (for signature verification): ",
|
||||
validate_api_key,
|
||||
"This does not look like a valid JWT secret. It should be at least 10 characters.",
|
||||
default_value=self.env_vars["supabase"]["SUPABASE_JWT_SECRET"],
|
||||
)
|
||||
print_success("Supabase information saved.")
|
||||
|
||||
def collect_daytona_info(self):
|
||||
|
|
Loading…
Reference in New Issue