mirror of https://github.com/kortix-ai/suna.git
379 lines
16 KiB
Python
379 lines
16 KiB
Python
"""
|
|
Agent Version Manager - Comprehensive module for managing agent versions
|
|
"""
|
|
from typing import Dict, List, Optional, Any, Union
|
|
from datetime import datetime
|
|
from uuid import UUID
|
|
import json
|
|
from pydantic import BaseModel, Field, validator
|
|
from utils.logger import logger
|
|
from services.supabase import DBConnection
|
|
from fastapi import HTTPException
|
|
|
|
|
|
class VersionData(BaseModel):
    """Payload describing the configuration captured by one agent version.

    Each collection field has a ``pre`` validator that replaces input of the
    wrong container type with an empty list/dict before field validation.
    """

    # Prompt text for the version; the only field without a default.
    system_prompt: str
    configured_mcps: List[Dict[str, Any]] = Field(default_factory=list)
    custom_mcps: List[Dict[str, Any]] = Field(default_factory=list)
    agentpress_tools: Dict[str, Any] = Field(default_factory=dict)

    @validator('custom_mcps', pre=True)
    def normalize_custom_mcps(cls, v):
        """Rewrite every custom MCP entry into one fixed five-key shape.

        Non-list input yields ``[]`` and non-dict entries are dropped.
        ``type`` and ``customType`` fall back to each other and then to
        ``'sse'``; ``enabledTools`` falls back to the snake_case
        ``enabled_tools`` key and then to ``[]``.
        """
        if not isinstance(v, list):
            return []

        normalized = []
        for mcp in v:
            if not isinstance(mcp, dict):
                continue
            normalized.append({
                'name': mcp.get('name', 'Unnamed MCP'),
                # The trailing ``or`` keeps a key that is present but falsy
                # (e.g. None) from defeating the default value.
                'type': mcp.get('type') or mcp.get('customType') or 'sse',
                'customType': mcp.get('customType') or mcp.get('type') or 'sse',
                'config': mcp.get('config', {}),
                'enabledTools': mcp.get('enabledTools') or mcp.get('enabled_tools') or [],
            })
        return normalized

    @validator('configured_mcps', pre=True)
    def normalize_configured_mcps(cls, v):
        """Return ``v`` unchanged when it is a list, otherwise ``[]``."""
        return v if isinstance(v, list) else []

    @validator('agentpress_tools', pre=True)
    def normalize_agentpress_tools(cls, v):
        """Return ``v`` unchanged when it is a dict, otherwise ``{}``."""
        return v if isinstance(v, dict) else {}
|
|
|
|
|
|
class AgentVersionManager:
|
|
"""Manager class for handling all agent version operations"""
|
|
|
|
def __init__(self):
    """Store a ``DBConnection`` handle and the imported module logger."""
    self.db = DBConnection()
    self.logger = logger
|
|
|
|
async def get_version(self, agent_id: str, version_id: str, user_id: str) -> Dict[str, Any]:
    """Fetch one version of an agent and return it in normalized form.

    Args:
        agent_id: Agent the version must belong to.
        version_id: Version to load.
        user_id: Caller; compared against the agent's ``account_id``.

    Returns:
        The version row passed through ``_normalize_version_data``.

    Raises:
        HTTPException: 404 when the agent or the version is not found,
            403 when the caller is not the owner and the agent is not public.
    """
    supabase = await self.db.client

    # Access check: the agent must exist and be owned by the caller or public.
    agent_rows = (
        await supabase.table('agents').select('*').eq('agent_id', agent_id).execute()
    ).data
    if not agent_rows:
        raise HTTPException(status_code=404, detail="Agent not found")

    agent_row = agent_rows[0]
    if not (agent_row['account_id'] == user_id or agent_row.get('is_public', False)):
        raise HTTPException(status_code=403, detail="Access denied")

    # The version is looked up by its own id and by the owning agent's id.
    version_rows = (
        await supabase.table('agent_versions').select('*').eq('version_id', version_id).eq('agent_id', agent_id).execute()
    ).data
    if not version_rows:
        raise HTTPException(status_code=404, detail="Version not found")

    return self._normalize_version_data(version_rows[0])
|
|
|
|
async def get_all_versions(self, agent_id: str, user_id: str) -> List[Dict[str, Any]]:
    """List every version of an agent in normalized form.

    The query orders rows by ``version_number`` descending.

    Args:
        agent_id: Agent whose versions are listed.
        user_id: Caller; compared against the agent's ``account_id``.

    Returns:
        One normalized dict per version row.

    Raises:
        HTTPException: 404 when the agent is not found, 403 when the caller
            is not the owner and the agent is not public.
    """
    supabase = await self.db.client

    # Access check: the agent must exist and be owned by the caller or public.
    agent_rows = (
        await supabase.table('agents').select('*').eq('agent_id', agent_id).execute()
    ).data
    if not agent_rows:
        raise HTTPException(status_code=404, detail="Agent not found")

    agent_row = agent_rows[0]
    if not (agent_row['account_id'] == user_id or agent_row.get('is_public', False)):
        raise HTTPException(status_code=403, detail="Access denied")

    version_rows = (
        await supabase.table('agent_versions').select('*').eq('agent_id', agent_id).order('version_number', desc=True).execute()
    ).data

    normalize = self._normalize_version_data
    return [normalize(row) for row in version_rows]
|
|
|
|
async def create_version(
    self,
    agent_id: str,
    version_data: VersionData,
    user_id: str,
    version_name: Optional[str] = None,
    description: Optional[str] = None
) -> Dict[str, Any]:
    """Insert a new version row and point the agent at it.

    Args:
        agent_id: Agent receiving the new version.
        version_data: Validated configuration to store.
        user_id: Caller; must match the agent's ``account_id``.
        version_name: Optional label; defaults to ``v<version_number>``.
        description: Optional text stored as ``change_description``.

    Returns:
        The inserted row passed through ``_normalize_version_data``.

    Raises:
        HTTPException: 404 when no agent matches both ``agent_id`` and
            ``user_id``; 500 when the insert returns no row or any other
            error occurs while writing.
    """
    client = await self.db.client

    # Ownership check: the query filters on both the agent id and the caller.
    agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute()
    if not agent_result.data:
        raise HTTPException(status_code=404, detail="Agent not found or access denied")

    # Next number is one past the highest stored version_number, or 1.
    # NOTE(review): read-then-insert is not atomic; two concurrent calls can
    # compute the same number unless the table enforces uniqueness.
    versions_result = await client.table('agent_versions').select('version_number').eq('agent_id', agent_id).order('version_number', desc=True).limit(1).execute()
    latest_rows = versions_result.data
    next_version_number = latest_rows[0]['version_number'] + 1 if latest_rows else 1

    new_version_data = {
        "agent_id": agent_id,
        "version_number": next_version_number,
        "version_name": version_name or f"v{next_version_number}",
        "system_prompt": version_data.system_prompt,
        "configured_mcps": version_data.configured_mcps,
        "custom_mcps": version_data.custom_mcps,
        "agentpress_tools": version_data.agentpress_tools,
        "is_active": True,
        "created_by": user_id,
        "change_description": description,
        # Unified config is stored alongside the individual columns.
        "config": self._build_unified_config(version_data),
    }

    try:
        version_result = await client.table('agent_versions').insert(new_version_data).execute()
        if not version_result.data:
            raise HTTPException(status_code=500, detail="Failed to create version")

        version = version_result.data[0]

        # Point the agent at the new version and record the new count.
        await client.table('agents').update({
            'current_version_id': version['version_id'],
            'version_count': next_version_number
        }).eq('agent_id', agent_id).execute()

        self.logger.info(f"Created version {version['version_name']} for agent {agent_id}")
        return self._normalize_version_data(version)

    except HTTPException:
        # Already a well-formed HTTP error; re-wrapping it below would
        # duplicate the message inside a second 500.
        raise
    except Exception as e:
        self.logger.error(f"Error creating version: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to create version: {str(e)}") from e
|
|
|
|
async def auto_create_version_on_config_change(
    self,
    agent_id: str,
    user_id: str,
    change_description: str = "Auto-saved configuration changes"
) -> Optional[str]:
    """Create and activate a new version when the agent's config has drifted.

    Compares ``system_prompt``, ``configured_mcps``, ``custom_mcps`` and
    ``agentpress_tools`` on the agent row with the same fields on the row of
    the agent's current version. When they differ, or when no current
    version row is found, a new version is created from the agent row and
    then activated.

    This is best-effort: every exception is logged and turned into ``None``.

    Args:
        agent_id: Agent to inspect.
        user_id: Caller, forwarded to ``create_version`` / ``activate_version``.
        change_description: Stored as the new version's change description.

    Returns:
        The new version's id, or ``None`` when the agent is missing, nothing
        changed, or an error occurred.
    """
    def snapshot(row: Dict[str, Any]) -> Dict[str, Any]:
        # The four fields that define a configuration; falsy collections
        # are replaced with empty ones so None and [] / {} compare equal.
        return {
            'system_prompt': row.get('system_prompt'),
            'configured_mcps': row.get('configured_mcps') or [],
            'custom_mcps': row.get('custom_mcps') or [],
            'agentpress_tools': row.get('agentpress_tools') or {}
        }

    try:
        # Same table-query client the other methods of this manager use.
        # NOTE(review): this previously built a fresh DBConnection and called
        # a raw-SQL ``fetch_one`` on it — confirm no caller relied on that.
        client = await self.db.client

        agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).execute()
        if not agent_result.data:
            self.logger.warning(f"Agent {agent_id} not found for auto-versioning")
            return None
        current_agent = agent_result.data[0]

        current_version = None
        current_version_id = current_agent.get('current_version_id')
        if current_version_id:
            version_result = await client.table('agent_versions').select('*').eq('version_id', current_version_id).execute()
            if version_result.data:
                current_version = version_result.data[0]

        current_config = snapshot(current_agent)
        version_config = snapshot(current_version) if current_version else None

        if version_config and current_config == version_config:
            self.logger.info(f"No configuration changes detected for agent {agent_id}")
            return None

        self.logger.info(f"Configuration changes detected for agent {agent_id}, creating auto-version")

        version_data = VersionData(
            system_prompt=current_config['system_prompt'],
            configured_mcps=current_config['configured_mcps'],
            custom_mcps=current_config['custom_mcps'],
            agentpress_tools=current_config['agentpress_tools']
        )

        new_version = await self.create_version(
            agent_id=agent_id,
            version_data=version_data,
            user_id=user_id,
            description=change_description
        )

        if new_version:
            await self.activate_version(agent_id, new_version['version_id'], user_id)
            self.logger.info(f"Auto-created and activated version {new_version['version_name']} for agent {agent_id}")
            return new_version['version_id']

        return None

    except Exception as e:
        # Deliberate best-effort: auto-versioning must not fail the caller.
        self.logger.error(f"Failed to auto-create version for agent {agent_id}: {e}")
        return None
|
|
|
|
async def activate_version(self, agent_id: str, version_id: str, user_id: str) -> None:
    """Make ``version_id`` the agent's current version.

    Args:
        agent_id: Agent to update.
        version_id: Version to activate; must belong to ``agent_id``.
        user_id: Caller; must match the agent's ``account_id``.

    Raises:
        HTTPException: 404 when no agent matches both ``agent_id`` and
            ``user_id``, or when the version is not found for this agent.
    """
    supabase = await self.db.client

    # Ownership check: the query filters on both the agent id and the caller.
    owned_agents = (
        await supabase.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute()
    ).data
    if not owned_agents:
        raise HTTPException(status_code=404, detail="Agent not found or access denied")

    # The version must exist and belong to this agent.
    matching_versions = (
        await supabase.table('agent_versions').select('*').eq('version_id', version_id).eq('agent_id', agent_id).execute()
    ).data
    if not matching_versions:
        raise HTTPException(status_code=404, detail="Version not found")

    pointer_update = {'current_version_id': version_id}
    await supabase.table('agents').update(pointer_update).eq('agent_id', agent_id).execute()

    self.logger.info(f"Activated version {version_id} for agent {agent_id}")
|
|
|
|
async def compare_versions(
    self,
    agent_id: str,
    version1_id: str,
    version2_id: str,
    user_id: str
) -> Dict[str, Any]:
    """Load two versions of one agent and report how they differ.

    Both versions are fetched through ``get_version``, so its access checks
    and errors apply to each lookup.

    Returns:
        A dict with keys ``version1`` and ``version2`` (the normalized
        versions) and ``differences`` (the result of
        ``_calculate_differences``).
    """
    first = await self.get_version(agent_id, version1_id, user_id)
    second = await self.get_version(agent_id, version2_id, user_id)

    comparison = {'version1': first, 'version2': second}
    comparison['differences'] = self._calculate_differences(first, second)
    return comparison
|
|
|
|
def _normalize_version_data(self, version_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Normalize version data to ensure consistent structure"""
|
|
# Handle custom MCPs normalization
|
|
custom_mcps = version_data.get('custom_mcps', [])
|
|
if custom_mcps is None:
|
|
custom_mcps = []
|
|
elif isinstance(custom_mcps, list):
|
|
custom_mcps = [
|
|
{
|
|
'name': mcp.get('name', 'Unnamed MCP'),
|
|
'type': mcp.get('type') or mcp.get('customType', 'sse'),
|
|
'customType': mcp.get('customType') or mcp.get('type', 'sse'),
|
|
'config': mcp.get('config', {}),
|
|
'enabledTools': mcp.get('enabledTools') or mcp.get('enabled_tools', [])
|
|
}
|
|
for mcp in custom_mcps
|
|
if isinstance(mcp, dict)
|
|
]
|
|
|
|
# Handle configured MCPs
|
|
configured_mcps = version_data.get('configured_mcps', [])
|
|
if configured_mcps is None:
|
|
configured_mcps = []
|
|
elif not isinstance(configured_mcps, list):
|
|
configured_mcps = []
|
|
|
|
# Handle agentpress tools
|
|
agentpress_tools = version_data.get('agentpress_tools', {})
|
|
if not isinstance(agentpress_tools, dict):
|
|
agentpress_tools = {}
|
|
|
|
return {
|
|
'version_id': version_data['version_id'],
|
|
'agent_id': version_data['agent_id'],
|
|
'version_number': version_data['version_number'],
|
|
'version_name': version_data['version_name'],
|
|
'system_prompt': version_data.get('system_prompt', ''),
|
|
'configured_mcps': configured_mcps,
|
|
'custom_mcps': custom_mcps,
|
|
'agentpress_tools': agentpress_tools,
|
|
'is_active': version_data.get('is_active', True),
|
|
'created_at': version_data['created_at'],
|
|
'updated_at': version_data.get('updated_at', version_data['created_at']),
|
|
'created_by': version_data.get('created_by'),
|
|
'change_description': version_data.get('change_description'),
|
|
'config': version_data.get('config', {})
|
|
}
|
|
|
|
def _build_unified_config(self, version_data: VersionData) -> Dict[str, Any]:
|
|
"""Build unified config object"""
|
|
simplified_tools = {}
|
|
for tool_name, tool_config in version_data.agentpress_tools.items():
|
|
if isinstance(tool_config, dict):
|
|
simplified_tools[tool_name] = tool_config.get('enabled', False)
|
|
elif isinstance(tool_config, bool):
|
|
simplified_tools[tool_name] = tool_config
|
|
|
|
return {
|
|
'system_prompt': version_data.system_prompt,
|
|
'tools': {
|
|
'agentpress': simplified_tools,
|
|
'mcp': version_data.configured_mcps,
|
|
'custom_mcp': version_data.custom_mcps
|
|
}
|
|
}
|
|
|
|
def _calculate_differences(self, version1: Dict[str, Any], version2: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Calculate differences between two versions"""
|
|
differences = {}
|
|
|
|
# Check system prompt
|
|
if version1['system_prompt'] != version2['system_prompt']:
|
|
differences['system_prompt'] = {
|
|
'changed': True,
|
|
'version1': version1['system_prompt'][:100] + '...' if len(version1['system_prompt']) > 100 else version1['system_prompt'],
|
|
'version2': version2['system_prompt'][:100] + '...' if len(version2['system_prompt']) > 100 else version2['system_prompt']
|
|
}
|
|
|
|
# Check tools
|
|
v1_tools = set(version1['agentpress_tools'].keys())
|
|
v2_tools = set(version2['agentpress_tools'].keys())
|
|
|
|
tools_added = v2_tools - v1_tools
|
|
tools_removed = v1_tools - v2_tools
|
|
tools_changed = []
|
|
|
|
for tool in v1_tools & v2_tools:
|
|
if version1['agentpress_tools'][tool] != version2['agentpress_tools'][tool]:
|
|
tools_changed.append(tool)
|
|
|
|
if tools_added or tools_removed or tools_changed:
|
|
differences['agentpress_tools'] = {
|
|
'added': list(tools_added),
|
|
'removed': list(tools_removed),
|
|
'changed': tools_changed
|
|
}
|
|
|
|
# Check MCPs
|
|
if json.dumps(version1['configured_mcps'], sort_keys=True) != json.dumps(version2['configured_mcps'], sort_keys=True):
|
|
differences['configured_mcps'] = {
|
|
'version1_count': len(version1['configured_mcps']),
|
|
'version2_count': len(version2['configured_mcps'])
|
|
}
|
|
|
|
if json.dumps(version1['custom_mcps'], sort_keys=True) != json.dumps(version2['custom_mcps'], sort_keys=True):
|
|
differences['custom_mcps'] = {
|
|
'version1_count': len(version1['custom_mcps']),
|
|
'version2_count': len(version2['custom_mcps'])
|
|
}
|
|
|
|
return differences
|
|
|
|
|
|
# Singleton instance, constructed once when this module is imported.
version_manager = AgentVersionManager()