mirror of https://github.com/kortix-ai/suna.git
212 lines
7.5 KiB
Python
212 lines
7.5 KiB
Python
|
from fastapi import APIRouter, HTTPException, Depends
|
||
|
from typing import List, Optional
|
||
|
from pydantic import BaseModel, Field
|
||
|
from datetime import datetime
|
||
|
|
||
|
from ..domain.entities import AgentId, VersionId, UserId
|
||
|
from ..services.version_service import VersionService
|
||
|
from ..services.exceptions import (
|
||
|
VersionNotFoundError, AgentNotFoundError, UnauthorizedError,
|
||
|
InvalidVersionError
|
||
|
)
|
||
|
from ..infrastructure.dependencies import get_version_service
|
||
|
from utils.auth_utils import get_current_user_id_from_jwt
|
||
|
|
||
|
|
||
|
# Router for agent version management. Every route below is mounted under
# /agents/{agent_id}/versions, which is why each handler takes an
# ``agent_id`` path parameter.
router = APIRouter(prefix="/agents/{agent_id}/versions", tags=["versions"])
|
||
|
|
||
|
|
||
|
class VersionResponse(BaseModel):
    # API representation of one agent version. Handlers build it with
    # ``VersionResponse(**version.to_dict())``, so these field names correspond
    # to keys produced by the domain entity's ``to_dict()``.
    version_id: str
    agent_id: str
    # NOTE(review): presumably a per-agent sequence number — confirm in service.
    version_number: int
    version_name: str
    system_prompt: str
    # Lists of MCP configuration dicts; their inner schema is not validated here.
    # default_factory gives each instance its own fresh list.
    configured_mcps: List[dict] = Field(default_factory=list)
    custom_mcps: List[dict] = Field(default_factory=list)
    # Tool settings as a plain dict; NOTE(review): key/value schema defined
    # by the service layer — confirm before relying on it.
    agentpress_tools: dict = Field(default_factory=dict)
    is_active: bool
    created_at: datetime
    updated_at: datetime
    created_by: str
    # Optional free-text note; populated from CreateVersionRequest.description
    # when a version is created through this router.
    change_description: Optional[str] = None
|
||
|
|
||
|
|
||
|
class CreateVersionRequest(BaseModel):
    # Request body for POST /agents/{agent_id}/versions. Fields are forwarded
    # one-to-one to VersionService.create_version by the create_version handler.
    system_prompt: str
    # MCP configuration dicts; inner schema is not validated here.
    configured_mcps: List[dict] = Field(default_factory=list)
    custom_mcps: List[dict] = Field(default_factory=list)
    agentpress_tools: dict = Field(default_factory=dict)
    # Optional; NOTE(review): presumably the service picks a name when None.
    version_name: Optional[str] = None
    # Passed to the service as ``change_description`` (the field names differ).
    description: Optional[str] = None
|
||
|
|
||
|
|
||
|
class CompareVersionsResponse(BaseModel):
    # Response for the compare endpoint, built via
    # ``CompareVersionsResponse(**result)`` from the dict returned by
    # VersionService.compare_versions.
    version1: VersionResponse
    version2: VersionResponse
    # One dict per difference; NOTE(review): dict shape is defined by the
    # service layer — confirm there.
    differences: List[dict]
|
||
|
|
||
|
|
||
|
@router.get("", response_model=List[VersionResponse])
async def get_versions(
    agent_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """List all versions of an agent for the authenticated user.

    Args:
        agent_id: Agent identifier taken from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        One ``VersionResponse`` per version returned by the service.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``, 404 on
            ``AgentNotFoundError``, 500 on any other failure (the cause is
            logged server-side; the client only sees a generic message).
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        user_id_obj = UserId.from_string(user_id)

        versions = await version_service.get_all_versions(agent_id_obj, user_id_obj)

        return [VersionResponse(**version.to_dict()) for version in versions]
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except AgentNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # Record the real cause (with traceback) before replacing it with a
        # generic 500; previously the underlying error was discarded.
        from utils.logger import logger
        logger.exception(f"Failed to fetch versions for agent {agent_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch versions") from e
|
||
|
|
||
|
|
||
|
@router.post("", response_model=VersionResponse)
async def create_version(
    agent_id: str,
    request: CreateVersionRequest,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Create a new version of an agent from the request body.

    Args:
        agent_id: Agent identifier taken from the URL path.
        request: New version contents; ``request.description`` is passed to
            the service as ``change_description``.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        The created version as a ``VersionResponse``.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``, 404 on
            ``AgentNotFoundError``, 400 on ``ValueError``, 500 on any other
            failure (logged server-side with traceback; the client receives a
            generic message so internal error text is not exposed).
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        user_id_obj = UserId.from_string(user_id)

        version = await version_service.create_version(
            agent_id=agent_id_obj,
            user_id=user_id_obj,
            system_prompt=request.system_prompt,
            configured_mcps=request.configured_mcps,
            custom_mcps=request.custom_mcps,
            agentpress_tools=request.agentpress_tools,
            version_name=request.version_name,
            change_description=request.description
        )

        return VersionResponse(**version.to_dict())
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except AgentNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # The full error and traceback go to the server log only. Echoing
        # str(e) into the response leaked internal details to clients.
        from utils.logger import logger
        logger.exception(f"Failed to create version: {e}")
        raise HTTPException(status_code=500, detail="Failed to create version") from e
|
||
|
|
||
|
|
||
|
@router.get("/{version_id}", response_model=VersionResponse)
async def get_version(
    agent_id: str,
    version_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Fetch a single version of an agent.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version_id: Version identifier taken from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        The requested version as a ``VersionResponse``.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``; 404 on
            ``VersionNotFoundError`` or ``AgentNotFoundError``; 500 on any
            other failure (logged server-side, generic message to client).
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        version_id_obj = VersionId.from_string(version_id)
        user_id_obj = UserId.from_string(user_id)

        version = await version_service.get_version(
            agent_id_obj, version_id_obj, user_id_obj
        )

        return VersionResponse(**version.to_dict())
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except AgentNotFoundError as e:
        # Map a missing agent to 404, as get_versions/create_version do,
        # rather than letting it fall through to the generic 500 below.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        from utils.logger import logger
        logger.exception(f"Failed to fetch version {version_id} of agent {agent_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch version") from e
|
||
|
|
||
|
|
||
|
@router.put("/{version_id}/activate")
async def activate_version(
    agent_id: str,
    version_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Activate the given version of an agent via the version service.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version_id: Version identifier taken from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        ``{"message": "Version activated successfully"}`` on success.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``; 404 on
            ``VersionNotFoundError`` or ``AgentNotFoundError``; 400 on
            ``InvalidVersionError``; 500 on any other failure (logged
            server-side, generic message to client).
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        version_id_obj = VersionId.from_string(version_id)
        user_id_obj = UserId.from_string(user_id)

        await version_service.activate_version(
            agent_id_obj, version_id_obj, user_id_obj
        )

        return {"message": "Version activated successfully"}
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except InvalidVersionError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except AgentNotFoundError as e:
        # Consistent with get_versions/create_version: a missing agent is a 404.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        from utils.logger import logger
        logger.exception(f"Failed to activate version {version_id} of agent {agent_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to activate version") from e
|
||
|
|
||
|
|
||
|
@router.get("/compare/{version1_id}/{version2_id}", response_model=CompareVersionsResponse)
async def compare_versions(
    agent_id: str,
    version1_id: str,
    version2_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Compare two versions of the same agent.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version1_id: First version identifier from the URL path.
        version2_id: Second version identifier from the URL path.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        A ``CompareVersionsResponse`` built from the dict returned by
        ``version_service.compare_versions``.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``; 404 on
            ``VersionNotFoundError`` or ``AgentNotFoundError``; 500 on any
            other failure (logged server-side, generic message to client).
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        version1_id_obj = VersionId.from_string(version1_id)
        version2_id_obj = VersionId.from_string(version2_id)
        user_id_obj = UserId.from_string(user_id)

        result = await version_service.compare_versions(
            agent_id_obj, version1_id_obj, version2_id_obj, user_id_obj
        )

        return CompareVersionsResponse(**result)
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except AgentNotFoundError as e:
        # Consistent with get_versions/create_version: a missing agent is a 404.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        from utils.logger import logger
        logger.exception(
            f"Failed to compare versions {version1_id} and {version2_id} "
            f"of agent {agent_id}: {e}"
        )
        raise HTTPException(status_code=500, detail="Failed to compare versions") from e
|
||
|
|
||
|
|
||
|
@router.post("/{version_id}/rollback", response_model=VersionResponse)
async def rollback_to_version(
    agent_id: str,
    version_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt),
    version_service: VersionService = Depends(get_version_service)
):
    """Roll an agent back to the given version via the version service.

    Args:
        agent_id: Agent identifier taken from the URL path.
        version_id: Identifier of the version to roll back to.
        user_id: Caller's user id, resolved from the JWT.
        version_service: Injected version service.

    Returns:
        The version object returned by ``version_service.rollback_to_version``
        as a ``VersionResponse``. NOTE(review): presumably a newly created
        version copied from the target — confirm in the service.

    Raises:
        HTTPException: 403 on ``UnauthorizedError``; 404 on
            ``VersionNotFoundError`` or ``AgentNotFoundError``; 500 on any
            other failure (logged server-side, generic message to client).
    """
    try:
        agent_id_obj = AgentId.from_string(agent_id)
        version_id_obj = VersionId.from_string(version_id)
        user_id_obj = UserId.from_string(user_id)

        new_version = await version_service.rollback_to_version(
            agent_id_obj, version_id_obj, user_id_obj
        )

        return VersionResponse(**new_version.to_dict())
    except UnauthorizedError as e:
        raise HTTPException(status_code=403, detail=str(e)) from e
    except VersionNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except AgentNotFoundError as e:
        # Consistent with get_versions/create_version: a missing agent is a 404.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        from utils.logger import logger
        logger.exception(f"Failed to rollback agent {agent_id} to version {version_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to rollback version") from e
|