suna/backend/agent/versioning/version_service.py

454 lines
16 KiB
Python

import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional
from uuid import uuid4, UUID
from enum import Enum
from services.supabase import DBConnection
from utils.logger import logger
class VersionStatus(Enum):
ACTIVE = "active"
INACTIVE = "inactive"
ARCHIVED = "archived"
@dataclass
class AgentVersion:
version_id: str
agent_id: str
version_number: int
version_name: str
system_prompt: str
configured_mcps: List[Dict[str, Any]] = field(default_factory=list)
custom_mcps: List[Dict[str, Any]] = field(default_factory=list)
agentpress_tools: Dict[str, Any] = field(default_factory=dict)
is_active: bool = True
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
created_by: str = ""
change_description: Optional[str] = None
previous_version_id: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
'version_id': self.version_id,
'agent_id': self.agent_id,
'version_number': self.version_number,
'version_name': self.version_name,
'system_prompt': self.system_prompt,
'configured_mcps': self.configured_mcps,
'custom_mcps': self.custom_mcps,
'agentpress_tools': self.agentpress_tools,
'is_active': self.is_active,
'status': VersionStatus.ACTIVE.value if self.is_active else VersionStatus.INACTIVE.value,
'created_at': self.created_at.isoformat(),
'updated_at': self.updated_at.isoformat(),
'created_by': self.created_by,
'change_description': self.change_description,
'previous_version_id': self.previous_version_id
}
class VersionServiceError(Exception):
pass
class VersionNotFoundError(VersionServiceError):
pass
class AgentNotFoundError(VersionServiceError):
pass
class UnauthorizedError(VersionServiceError):
pass
class InvalidVersionError(VersionServiceError):
pass
class VersionConflictError(VersionServiceError):
pass
class VersionService:
def __init__(self):
self.db = DBConnection()
async def _get_client(self):
return await self.db.client
async def _verify_agent_access(self, agent_id: str, user_id: str) -> tuple[bool, bool]:
if user_id == "system":
return True, True
client = await self._get_client()
owner_result = await client.table('agents').select('account_id').eq(
'agent_id', agent_id
).eq('account_id', user_id).execute()
is_owner = bool(owner_result.data)
public_result = await client.table('agents').select('is_public').eq(
'agent_id', agent_id
).execute()
is_public = bool(public_result.data and public_result.data[0].get('is_public', False))
return is_owner, is_public
async def _get_next_version_number(self, agent_id: str) -> int:
client = await self._get_client()
result = await client.table('agent_versions').select(
'version_number'
).eq('agent_id', agent_id).order(
'version_number', desc=True
).limit(1).execute()
if not result.data:
return 1
return result.data[0]['version_number'] + 1
async def _count_versions(self, agent_id: str) -> int:
client = await self._get_client()
result = await client.table('agent_versions').select(
'version_id', count='exact'
).eq('agent_id', agent_id).execute()
return result.count or 0
async def _update_agent_current_version(self, agent_id: str, version_id: str, version_count: int):
client = await self._get_client()
data = {
'current_version_id': version_id,
'version_count': version_count
}
result = await client.table('agents').update(data).eq(
'agent_id', agent_id
).execute()
if not result.data:
raise Exception("Failed to update agent current version")
def _version_from_db_row(self, row: Dict[str, Any]) -> AgentVersion:
config = row.get('config', {})
tools = config.get('tools', {})
return AgentVersion(
version_id=row['version_id'],
agent_id=row['agent_id'],
version_number=row['version_number'],
version_name=row['version_name'],
system_prompt=config.get('system_prompt', ''),
configured_mcps=tools.get('mcp', []),
custom_mcps=tools.get('custom_mcp', []),
agentpress_tools=tools.get('agentpress', {}),
is_active=row.get('is_active', False),
created_at=datetime.fromisoformat(row['created_at'].replace('Z', '+00:00')),
updated_at=datetime.fromisoformat(row['updated_at'].replace('Z', '+00:00')),
created_by=row['created_by'],
change_description=row.get('change_description'),
previous_version_id=row.get('previous_version_id')
)
async def create_version(
self,
agent_id: str,
user_id: str,
system_prompt: str,
configured_mcps: List[Dict[str, Any]],
custom_mcps: List[Dict[str, Any]],
agentpress_tools: Dict[str, Any],
version_name: Optional[str] = None,
change_description: Optional[str] = None
) -> AgentVersion:
logger.info(f"Creating version for agent {agent_id}")
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("You don't have permission to create versions for this agent")
client = await self._get_client()
agent_result = await client.table('agents').select('*').eq(
'agent_id', agent_id
).execute()
if not agent_result.data:
raise AgentNotFoundError(f"Agent {agent_id} not found")
version_number = await self._get_next_version_number(agent_id)
version_name = version_name or f"v{version_number}"
current_active_result = await client.table('agent_versions').select('*').eq(
'agent_id', agent_id
).eq('is_active', True).execute()
previous_version_id = None
if current_active_result.data:
previous_version_id = current_active_result.data[0]['version_id']
await client.table('agent_versions').update({
'is_active': False,
'updated_at': datetime.now(timezone.utc).isoformat()
}).eq('version_id', previous_version_id).execute()
version = AgentVersion(
version_id=str(uuid4()),
agent_id=agent_id,
version_number=version_number,
version_name=version_name,
system_prompt=system_prompt,
configured_mcps=configured_mcps,
custom_mcps=custom_mcps,
agentpress_tools=agentpress_tools,
is_active=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
created_by=user_id,
change_description=change_description,
previous_version_id=previous_version_id
)
data = {
'version_id': version.version_id,
'agent_id': version.agent_id,
'version_number': version.version_number,
'version_name': version.version_name,
'is_active': version.is_active,
'created_at': version.created_at.isoformat(),
'updated_at': version.updated_at.isoformat(),
'created_by': version.created_by,
'change_description': version.change_description,
'previous_version_id': version.previous_version_id,
'config': {
'system_prompt': version.system_prompt,
'tools': {
'agentpress': version.agentpress_tools,
'mcp': version.configured_mcps,
'custom_mcp': version.custom_mcps
}
}
}
result = await client.table('agent_versions').insert(data).execute()
if not result.data:
raise Exception("Failed to create version")
version_count = await self._count_versions(agent_id)
await self._update_agent_current_version(agent_id, version.version_id, version_count)
logger.info(f"Created version {version.version_name} for agent {agent_id}")
return version
async def get_version(self, agent_id: str, version_id: str, user_id: str) -> AgentVersion:
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
if not is_owner and not is_public:
raise UnauthorizedError("You don't have permission to view this version")
client = await self._get_client()
result = await client.table('agent_versions').select('*').eq(
'version_id', version_id
).eq('agent_id', agent_id).execute()
if not result.data:
raise VersionNotFoundError(f"Version {version_id} not found")
return self._version_from_db_row(result.data[0])
async def get_active_version(self, agent_id: str, user_id: str = "system") -> Optional[AgentVersion]:
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
if not is_owner and not is_public:
raise UnauthorizedError("You don't have permission to view this agent")
client = await self._get_client()
result = await client.table('agent_versions').select('*').eq(
'agent_id', agent_id
).eq('is_active', True).execute()
if not result.data:
return None
return self._version_from_db_row(result.data[0])
async def get_all_versions(self, agent_id: str, user_id: str) -> List[AgentVersion]:
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
if not is_owner and not is_public:
raise UnauthorizedError("You don't have permission to view versions")
client = await self._get_client()
result = await client.table('agent_versions').select('*').eq(
'agent_id', agent_id
).order('version_number', desc=True).execute()
versions = [self._version_from_db_row(row) for row in result.data]
return versions
async def activate_version(self, agent_id: str, version_id: str, user_id: str) -> None:
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("You don't have permission to activate versions")
client = await self._get_client()
version_result = await client.table('agent_versions').select('*').eq(
'version_id', version_id
).eq('agent_id', agent_id).execute()
if not version_result.data:
raise VersionNotFoundError(f"Version {version_id} not found")
version = version_result.data[0]
await client.table('agent_versions').update({
'is_active': False,
'updated_at': datetime.now(timezone.utc).isoformat()
}).eq('agent_id', agent_id).eq('is_active', True).execute()
await client.table('agent_versions').update({
'is_active': True,
'updated_at': datetime.now(timezone.utc).isoformat()
}).eq('version_id', version_id).execute()
version_count = await self._count_versions(agent_id)
await self._update_agent_current_version(agent_id, version_id, version_count)
logger.info(f"Activated version {version['version_name']} for agent {agent_id}")
async def compare_versions(
self,
agent_id: str,
version1_id: str,
version2_id: str,
user_id: str
) -> Dict[str, Any]:
version1 = await self.get_version(agent_id, version1_id, user_id)
version2 = await self.get_version(agent_id, version2_id, user_id)
differences = self._calculate_differences(version1, version2)
return {
'version1': version1.to_dict(),
'version2': version2.to_dict(),
'differences': differences
}
def _calculate_differences(self, v1: AgentVersion, v2: AgentVersion) -> List[Dict[str, Any]]:
differences = []
if v1.system_prompt != v2.system_prompt:
differences.append({
'field': 'system_prompt',
'type': 'modified',
'old_value': v1.system_prompt,
'new_value': v2.system_prompt
})
v1_tools = set(v1.agentpress_tools.keys())
v2_tools = set(v2.agentpress_tools.keys())
for tool in v2_tools - v1_tools:
differences.append({
'field': f'tool.{tool}',
'type': 'added',
'new_value': v2.agentpress_tools[tool]
})
for tool in v1_tools - v2_tools:
differences.append({
'field': f'tool.{tool}',
'type': 'removed',
'old_value': v1.agentpress_tools[tool]
})
for tool in v1_tools & v2_tools:
if v1.agentpress_tools[tool] != v2.agentpress_tools[tool]:
differences.append({
'field': f'tool.{tool}',
'type': 'modified',
'old_value': v1.agentpress_tools[tool],
'new_value': v2.agentpress_tools[tool]
})
return differences
async def rollback_to_version(
self,
agent_id: str,
version_id: str,
user_id: str
) -> AgentVersion:
version_to_restore = await self.get_version(agent_id, version_id, user_id)
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("You don't have permission to rollback versions")
new_version = await self.create_version(
agent_id=agent_id,
user_id=user_id,
system_prompt=version_to_restore.system_prompt,
configured_mcps=version_to_restore.configured_mcps,
custom_mcps=version_to_restore.custom_mcps,
agentpress_tools=version_to_restore.agentpress_tools,
change_description=f"Rolled back to version {version_to_restore.version_name}"
)
return new_version
async def update_version_details(
self,
agent_id: str,
version_id: str,
user_id: str,
version_name: Optional[str] = None,
change_description: Optional[str] = None
) -> AgentVersion:
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("You don't have permission to update this version")
client = await self._get_client()
version_result = await client.table('agent_versions').select('*').eq(
'version_id', version_id
).eq('agent_id', agent_id).execute()
if not version_result.data:
raise VersionNotFoundError(f"Version {version_id} not found")
update_data = {
'updated_at': datetime.now(timezone.utc).isoformat()
}
if version_name is not None:
update_data['version_name'] = version_name
if change_description is not None:
update_data['change_description'] = change_description
result = await client.table('agent_versions').update(update_data).eq(
'version_id', version_id
).execute()
if not result.data:
raise Exception("Failed to update version")
return self._version_from_db_row(result.data[0])
_version_service_instance = None
async def get_version_service() -> VersionService:
global _version_service_instance
if _version_service_instance is None:
_version_service_instance = VersionService()
return _version_service_instance