suna/backend/agent/versioning/api/routes.py

286 lines
11 KiB
Python
Raw Normal View History

2025-07-14 01:56:24 +08:00
from fastapi import APIRouter, HTTPException, Depends
from typing import List, Optional
from pydantic import BaseModel, Field
from datetime import datetime
from ..domain.entities import AgentId, VersionId, UserId
from ..services.version_service import VersionService
from ..services.exceptions import (
VersionNotFoundError, AgentNotFoundError, UnauthorizedError,
InvalidVersionError
)
from ..infrastructure.dependencies import get_version_service
from utils.auth_utils import get_current_user_id_from_jwt
# Every route in this module is mounted under a single agent's versions
# collection; `agent_id` is therefore a path parameter of each handler.
router = APIRouter(
    prefix="/agents/{agent_id}/versions",
    tags=["versions"],
)
class VersionResponse(BaseModel):
    """Response body describing one agent version.

    The route handlers in this module build it from the ``to_dict()`` output
    of the version objects returned by the version service.
    """

    # --- identity -----------------------------------------------------------
    version_id: str
    agent_id: str
    version_number: int
    version_name: str

    # --- configuration carried by the version -------------------------------
    system_prompt: str
    # The three tool/MCP containers default to empty when omitted.
    configured_mcps: List[dict] = Field(default_factory=list)
    custom_mcps: List[dict] = Field(default_factory=list)
    agentpress_tools: dict = Field(default_factory=dict)

    # --- state and audit fields ---------------------------------------------
    is_active: bool
    created_at: datetime
    updated_at: datetime
    created_by: str
    change_description: Optional[str] = None
class CreateVersionRequest(BaseModel):
    """Request body accepted by the create-version endpoint.

    Only ``system_prompt`` is required; the tool/MCP containers default to
    empty and the naming fields default to ``None``.
    """

    system_prompt: str

    configured_mcps: List[dict] = Field(default_factory=list)
    custom_mcps: List[dict] = Field(default_factory=list)
    agentpress_tools: dict = Field(default_factory=dict)

    version_name: Optional[str] = None
    # Forwarded to the version service as `change_description`.
    description: Optional[str] = None
class UpdateVersionDetailsRequest(BaseModel):
    """Request body for editing a version's descriptive details.

    Both fields are optional and default to ``None``.
    """

    version_name: Optional[str] = None
    change_description: Optional[str] = None
2025-07-14 01:56:24 +08:00
class CompareVersionsResponse(BaseModel):
    """Response body of the compare endpoint: both versions plus their diff."""

    version1: VersionResponse
    version2: VersionResponse
    # One dict per difference; the dict layout is defined by the version
    # service and is not constrained here.
    differences: List[dict]
@router.get("", response_model=List[VersionResponse])
async def get_versions(
    agent_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Return all versions of an agent as reported by the version service.

    Args:
        agent_id: Agent identifier taken from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        A list of ``VersionResponse`` objects, one per version.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``, 404 on
            ``AgentNotFoundError``, 500 for any other failure.
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        user_id_obj = UserId.from_string(user_id)
        versions = await version_service.get_all_versions(agent_id_obj, user_id_obj)
        return [VersionResponse(**version.to_dict()) for version in versions]
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except AgentNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # Log before masking the error: the client only sees a generic 500,
        # so without this the root cause would be lost entirely.
        from utils.logger import logger
        logger.error(f"Failed to fetch versions for agent {agent_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch versions") from e
@router.post("", response_model=VersionResponse)
async def create_version(
    agent_id: str,
    request: CreateVersionRequest,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Create a new version for an agent owned by the caller.

    The handler first checks that a row exists in the ``agents`` table with
    this ``agent_id`` and ``account_id == user_id``; only then is the version
    service asked to create the version.

    Args:
        agent_id: Agent identifier taken from the URL path.
        request: Version content and optional name/description.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        The created version as a ``VersionResponse``.

    Raises:
        HTTPException: 404 when the agent row is missing or on
            ``AgentNotFoundError``, 403 on ``UnauthorizedError``, 400 on
            ``ValueError``, 500 for any other failure.
    """
    try:
        from services.supabase import DBConnection

        db = DBConnection()
        client = await db.client
        # Only existence is checked, so a single column is enough.
        agent_result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', user_id).maybe_single().execute()
        # Some client versions return None (not an empty response) from
        # maybe_single() when no row matches, so guard both shapes.
        if agent_result is None or not agent_result.data:
            raise HTTPException(status_code=404, detail="Agent not found")

        agent_id_obj = AgentId.from_string(agent_id)
        user_id_obj = UserId.from_string(user_id)
        version = await version_service.create_version(
            agent_id=agent_id_obj,
            user_id=user_id_obj,
            system_prompt=request.system_prompt,
            configured_mcps=request.configured_mcps,
            custom_mcps=request.custom_mcps,
            agentpress_tools=request.agentpress_tools,
            version_name=request.version_name,
            change_description=request.description
        )
        return VersionResponse(**version.to_dict())
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged;
        # otherwise the generic handler below would turn them into a 500.
        raise
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except AgentNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        import traceback
        from utils.logger import logger
        logger.error(f"Failed to create version: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        # NOTE(review): the response detail exposes the raw exception text to
        # the client, unlike the other endpoints; kept for compatibility.
        raise HTTPException(status_code=500, detail=f"Failed to create version: {str(e)}") from e
@router.get("/{version_id}", response_model=VersionResponse)
async def get_version(
    agent_id: str,
    version_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Return a single version of an agent.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version_id: Version identifier taken from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        The requested version as a ``VersionResponse``.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``, 404 on
            ``VersionNotFoundError``, 500 for any other failure.
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        version_id_obj = VersionId.from_string(version_id)
        user_id_obj = UserId.from_string(user_id)
        version = await version_service.get_version(
            agent_id_obj, version_id_obj, user_id_obj
        )
        return VersionResponse(**version.to_dict())
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # Log before masking the error behind a generic 500 response.
        from utils.logger import logger
        logger.error(f"Failed to fetch version {version_id} for agent {agent_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch version") from e
@router.put("/{version_id}/activate")
async def activate_version(
    agent_id: str,
    version_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Activate a version of an agent owned by the caller.

    The handler loads the agent's ``metadata`` (filtered by ``agent_id`` and
    ``account_id == user_id``). Activation on an agent flagged
    ``is_suna_default`` is allowed but logged before and after the call.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version_id: Version identifier taken from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        ``{"message": "Version activated successfully"}`` on success.

    Raises:
        HTTPException: 404 when the agent row is missing or on
            ``VersionNotFoundError``, 403 on ``UnauthorizedError``, 400 on
            ``InvalidVersionError``, 500 for any other failure.
    """
    try:
        from services.supabase import DBConnection
        from utils.logger import logger

        db = DBConnection()
        client = await db.client
        agent_result = await client.table('agents').select('metadata').eq('agent_id', agent_id).eq('account_id', user_id).maybe_single().execute()
        # Some client versions return None (not an empty response) from
        # maybe_single() when no row matches, so guard both shapes.
        if agent_result is None or not agent_result.data:
            raise HTTPException(status_code=404, detail="Agent not found")

        # `metadata` may be present but null; `or {}` covers that case, which
        # a `.get('metadata', {})` default would not.
        agent_metadata = agent_result.data.get('metadata') or {}
        is_suna_agent = agent_metadata.get('is_suna_default', False)

        if is_suna_agent:
            logger.warning(f"Version activation attempt on Suna default agent {agent_id} by user {user_id} for version {version_id}")
            logger.info(f"Allowing version activation for Suna agent {agent_id} - monitoring for security compliance")

        agent_id_obj = AgentId.from_string(agent_id)
        version_id_obj = VersionId.from_string(version_id)
        user_id_obj = UserId.from_string(user_id)
        await version_service.activate_version(
            agent_id_obj, version_id_obj, user_id_obj
        )

        if is_suna_agent:
            logger.info(f"Successfully activated version {version_id} for Suna agent {agent_id} by user {user_id}")

        return {"message": "Version activated successfully"}
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged;
        # otherwise the generic handler below would turn them into a 500.
        raise
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except InvalidVersionError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Re-import here: the failure may have happened before the import
        # inside the try block bound `logger`.
        from utils.logger import logger
        logger.error(f"Failed to activate version {version_id} for agent {agent_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to activate version") from e
@router.get("/compare/{version1_id}/{version2_id}", response_model=CompareVersionsResponse)
async def compare_versions(
    agent_id: str,
    version1_id: str,
    version2_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Compare two versions of the same agent.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version1_id: First version identifier from the URL path.
        version2_id: Second version identifier from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        A ``CompareVersionsResponse`` built from the mapping returned by the
        version service.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``, 404 on
            ``VersionNotFoundError``, 500 for any other failure.
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        version1_id_obj = VersionId.from_string(version1_id)
        version2_id_obj = VersionId.from_string(version2_id)
        user_id_obj = UserId.from_string(user_id)
        result = await version_service.compare_versions(
            agent_id_obj, version1_id_obj, version2_id_obj, user_id_obj
        )
        return CompareVersionsResponse(**result)
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # Log before masking the error behind a generic 500 response.
        from utils.logger import logger
        logger.error(
            f"Failed to compare versions {version1_id} and {version2_id} "
            f"for agent {agent_id}: {str(e)}"
        )
        raise HTTPException(status_code=500, detail="Failed to compare versions") from e
@router.post("/{version_id}/rollback", response_model=VersionResponse)
async def rollback_to_version(
    agent_id: str,
    version_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Roll an agent back to a given version.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version_id: Identifier of the version to roll back to.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        The version object returned by the service's rollback call,
        serialized as a ``VersionResponse``.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``, 404 on
            ``VersionNotFoundError``, 500 for any other failure.
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        version_id_obj = VersionId.from_string(version_id)
        user_id_obj = UserId.from_string(user_id)
        new_version = await version_service.rollback_to_version(
            agent_id_obj, version_id_obj, user_id_obj
        )
        return VersionResponse(**new_version.to_dict())
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # Log before masking the error behind a generic 500 response.
        from utils.logger import logger
        logger.error(f"Failed to rollback agent {agent_id} to version {version_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to rollback version") from e
@router.put("/{version_id}/details", response_model=VersionResponse)
async def update_version_details(
    agent_id: str,
    version_id: str,
    request: UpdateVersionDetailsRequest,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service),
):
    """Update the name and/or change description of a version.

    Path and JWT identifiers are converted to domain ids, passed to the
    version service together with the request fields, and the version the
    service returns is serialized as a ``VersionResponse``.

    Error mapping: ``UnauthorizedError`` -> 403, ``VersionNotFoundError`` ->
    404, ``ValueError`` -> 400, anything else -> logged and returned as 500.
    """
    try:
        target_agent = AgentId.from_string(agent_id)
        target_version = VersionId.from_string(version_id)
        acting_user = UserId.from_string(user_id)

        refreshed = await version_service.update_version_details(
            agent_id=target_agent,
            version_id=target_version,
            user_id=acting_user,
            version_name=request.version_name,
            change_description=request.change_description,
        )

        payload = refreshed.to_dict()
        return VersionResponse(**payload)
    except UnauthorizedError as denied:
        raise HTTPException(status_code=403, detail=str(denied))
    except VersionNotFoundError as missing:
        raise HTTPException(status_code=404, detail=str(missing))
    except ValueError as invalid:
        raise HTTPException(status_code=400, detail=str(invalid))
    except Exception as unexpected:
        from utils.logger import logger

        logger.error(f"Failed to update version details: {str(unexpected)}")
        raise HTTPException(status_code=500, detail="Failed to update version details")