suna/backend/templates/services/creation_service.py

171 lines
7.7 KiB
Python

from datetime import datetime, timezone
from typing import Dict, Any, Optional, List
from uuid import uuid4
from ..domain.entities import AgentTemplate, ConfigType, MCPRequirementValue
from ..domain.exceptions import TemplateNotFoundError, TemplateAccessDeniedError
from ..repositories.template_repository import TemplateRepository
from ..repositories.agent_repository import AgentRepository
from ..support.validator import TemplateValidator
from ..support.factory import MCPRequirementFactory
from ..protocols import VersionManager, Logger
class TemplateCreationService:
def __init__(
self,
template_repo: TemplateRepository,
agent_repo: AgentRepository,
version_manager: Optional[VersionManager],
validator: TemplateValidator,
factory: MCPRequirementFactory,
logger: Logger
):
self._template_repo = template_repo
self._agent_repo = agent_repo
self._version_manager = version_manager
self._validator = validator
self._factory = factory
self._logger = logger
async def create_from_agent(
self,
agent_id: str,
creator_id: str,
make_public: bool = False,
tags: Optional[List[str]] = None
) -> str:
self._logger.info(f"Creating template from agent {agent_id} for user {creator_id}")
agent = await self._agent_repo.find_by_id(agent_id)
if not agent:
raise TemplateNotFoundError("Agent not found")
if agent['account_id'] != creator_id:
raise TemplateAccessDeniedError("You can only create templates from your own agents")
version_data = await self._get_version_data(agent, creator_id)
config = self._extract_config(agent, version_data)
mcp_requirements = self._build_requirements(
config['configured_mcps'],
config['custom_mcps']
)
template = AgentTemplate(
template_id=str(uuid4()),
creator_id=creator_id,
name=agent['name'],
description=agent.get('description'),
system_prompt=config['system_prompt'],
mcp_requirements=mcp_requirements,
agentpress_tools=config['agentpress_tools'],
tags=tags or [],
is_public=make_public,
marketplace_published_at=datetime.now(timezone.utc) if make_public else None,
avatar=agent.get('avatar'),
avatar_color=agent.get('avatar_color'),
metadata={
'source_agent_id': agent_id,
'source_version_id': agent.get('current_version_id'),
'source_version_name': version_data.get('version_name', 'v1') if version_data else 'legacy'
}
)
saved_template = await self._template_repo.save(template)
self._logger.info(f"Successfully created template {saved_template.template_id}")
return saved_template.template_id
async def _get_version_data(self, agent: Dict[str, Any], user_id: str) -> Optional[Dict[str, Any]]:
if not self._version_manager or not agent.get('current_version_id'):
return None
try:
return await self._version_manager.get_version(
agent_id=agent['agent_id'],
version_id=agent['current_version_id'],
user_id=user_id
)
except Exception as e:
self._logger.warning(f"Failed to get version data: {e}")
return None
def _extract_config(self, agent: Dict[str, Any], version_data: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if version_data:
self._logger.info(f"Using version config for template creation")
configured_mcps = version_data.get('configured_mcps', [])
custom_mcps = version_data.get('custom_mcps', [])
self._logger.info(f"Version data - configured_mcps: {len(configured_mcps)}, custom_mcps: {len(custom_mcps)}")
return {
'system_prompt': version_data.get('system_prompt', ''),
'agentpress_tools': version_data.get('agentpress_tools', {}),
'configured_mcps': configured_mcps,
'custom_mcps': custom_mcps
}
else:
self._logger.info(f"Using legacy config for template creation")
config = agent.get('config', {})
if config and config != {}:
self._logger.info(f"Found unified config, extracting MCP data from it")
tools = config.get('tools', {})
configured_mcps = tools.get('mcp', [])
custom_mcps = tools.get('custom_mcp', [])
system_prompt = config.get('system_prompt', agent.get('system_prompt', ''))
agentpress_tools = tools.get('agentpress', agent.get('agentpress_tools', {}))
else:
self._logger.info(f"No unified config, using legacy columns")
configured_mcps = agent.get('configured_mcps', [])
custom_mcps = agent.get('custom_mcps', [])
system_prompt = agent.get('system_prompt', '')
agentpress_tools = agent.get('agentpress_tools', {})
self._logger.info(f"Extracted data - configured_mcps: {len(configured_mcps)}, custom_mcps: {len(custom_mcps)}")
self._logger.debug(f"Agent keys: {list(agent.keys())}")
return {
'system_prompt': system_prompt,
'agentpress_tools': agentpress_tools,
'configured_mcps': configured_mcps,
'custom_mcps': custom_mcps
}
def _build_requirements(
self,
configured_mcps: List[ConfigType],
custom_mcps: List[ConfigType]
) -> List[MCPRequirementValue]:
requirements = []
self._logger.info(f"Building requirements from {len(configured_mcps)} configured MCPs and {len(custom_mcps)} custom MCPs")
for i, mcp in enumerate(configured_mcps):
self._logger.debug(f"Processing configured MCP {i}: {mcp}")
if isinstance(mcp, dict):
has_qualified_name = 'qualifiedName' in mcp or 'qualified_name' in mcp
if has_qualified_name:
try:
req = self._factory.from_configured_mcp(mcp)
requirements.append(req)
self._logger.info(f"Added configured requirement: {req.qualified_name}")
except Exception as e:
self._logger.error(f"Error processing configured MCP {i}: {e}")
self._logger.debug(f"Failed MCP data: {mcp}")
else:
self._logger.warning(f"Skipping configured MCP {i}: missing qualifiedName field")
else:
self._logger.warning(f"Skipping configured MCP {i}: not a dict")
for i, custom_mcp in enumerate(custom_mcps):
self._logger.debug(f"Processing custom MCP {i}: {custom_mcp}")
if isinstance(custom_mcp, dict) and 'name' in custom_mcp:
try:
req = self._factory.from_custom_mcp(custom_mcp)
requirements.append(req)
self._logger.info(f"Added custom requirement: {req.qualified_name}")
except Exception as e:
self._logger.error(f"Error processing custom MCP {i}: {e}")
self._logger.debug(f"Failed custom MCP data: {custom_mcp}")
else:
self._logger.warning(f"Skipping custom MCP {i}: missing name field or not dict")
self._logger.info(f"Built {len(requirements)} total MCP requirements")
return requirements