from datetime import datetime, timezone
from typing import Dict, Any, Optional, List
from uuid import uuid4

from ..domain.entities import AgentTemplate, ConfigType, MCPRequirementValue
from ..domain.exceptions import TemplateNotFoundError, TemplateAccessDeniedError, SunaDefaultAgentTemplateError
from ..repositories.template_repository import TemplateRepository
from ..repositories.agent_repository import AgentRepository
from ..support.validator import TemplateValidator
from ..support.factory import MCPRequirementFactory
from ..protocols import VersionManager, Logger


class TemplateCreationService:
    """Create ``AgentTemplate`` records from an existing agent.

    The flow is:
    1. Load the agent.
    2. Check that the caller owns it and that it is not the Suna default agent.
    3. Extract its configuration. The agent's current version is used when
       available; otherwise the code falls back to the agent's ``config`` dict
       or its legacy columns.
    4. Convert the MCP entries into requirement values.
    5. Save the template.
    """

    def __init__(
        self,
        template_repo: TemplateRepository,
        agent_repo: AgentRepository,
        version_manager: Optional[VersionManager],
        validator: TemplateValidator,
        factory: MCPRequirementFactory,
        logger: Logger
    ):
        """Store the collaborators used by the service.

        Args:
            template_repo: Persists the new template via ``save``.
            agent_repo: Loads the source agent via ``find_by_id``.
            version_manager: Optional. When ``None``, version data is never
                fetched and the legacy config path is always used.
            validator: Stored, but not used by any method of this class.
            factory: Builds ``MCPRequirementValue`` objects from MCP dicts.
            logger: Receives progress, warning and error messages.
        """
        self._template_repo = template_repo
        self._agent_repo = agent_repo
        self._version_manager = version_manager
        self._validator = validator
        self._factory = factory
        self._logger = logger

    def _is_suna_default_agent(self, agent: Dict[str, Any]) -> bool:
        """Return True when the agent's metadata has a truthy ``is_suna_default`` flag."""
        # ``metadata`` may be present but null. Treat that like an empty dict
        # instead of raising AttributeError on ``None.get``.
        metadata = agent.get('metadata') or {}
        return bool(metadata.get('is_suna_default', False))

    async def create_from_agent(
        self,
        agent_id: str,
        creator_id: str,
        make_public: bool = False,
        tags: Optional[List[str]] = None
    ) -> str:
        """Create and save a template from one of the caller's agents.

        Args:
            agent_id: ID of the source agent.
            creator_id: ID of the requesting user. Must equal the agent's
                ``account_id``.
            make_public: When True, the template is marked public and
                ``marketplace_published_at`` is set to the current UTC time.
            tags: Optional tags for the template. ``None`` is stored as ``[]``.

        Returns:
            The ``template_id`` of the saved template.

        Raises:
            TemplateNotFoundError: The agent does not exist.
            TemplateAccessDeniedError: The agent belongs to another account.
            SunaDefaultAgentTemplateError: The agent is the Suna default agent.
        """
        self._logger.info(f"Creating template from agent {agent_id} for user {creator_id}")

        agent = await self._agent_repo.find_by_id(agent_id)
        if not agent:
            raise TemplateNotFoundError("Agent not found")

        if agent['account_id'] != creator_id:
            raise TemplateAccessDeniedError("You can only create templates from your own agents")

        # The default Suna agent must never be exported as a template.
        if self._is_suna_default_agent(agent):
            raise SunaDefaultAgentTemplateError("Cannot create templates from the default Suna agent")

        version_data = await self._get_version_data(agent, creator_id)
        config = self._extract_config(agent, version_data)

        mcp_requirements = self._build_requirements(
            config['configured_mcps'],
            config['custom_mcps']
        )

        # NOTE(review): 'source_version_id' is recorded even when the version
        # fetch failed and the legacy path was used ('legacy' version name).
        source_version_name = version_data.get('version_name', 'v1') if version_data else 'legacy'

        template = AgentTemplate(
            template_id=str(uuid4()),
            creator_id=creator_id,
            name=agent['name'],
            description=agent.get('description'),
            system_prompt=config['system_prompt'],
            mcp_requirements=mcp_requirements,
            agentpress_tools=config['agentpress_tools'],
            tags=tags or [],
            is_public=make_public,
            marketplace_published_at=datetime.now(timezone.utc) if make_public else None,
            avatar=agent.get('avatar'),
            avatar_color=agent.get('avatar_color'),
            metadata={
                'source_agent_id': agent_id,
                'source_version_id': agent.get('current_version_id'),
                'source_version_name': source_version_name
            }
        )

        saved_template = await self._template_repo.save(template)
        self._logger.info(f"Successfully created template {saved_template.template_id}")

        return saved_template.template_id

    async def _get_version_data(self, agent: Dict[str, Any], user_id: str) -> Optional[Dict[str, Any]]:
        """Fetch the agent's current version, or return None.

        Returns None when no version manager is configured, when the agent has
        no ``current_version_id``, or when the lookup raises. Failures are
        logged as warnings so template creation can fall back to legacy config.
        """
        if not self._version_manager or not agent.get('current_version_id'):
            return None

        try:
            return await self._version_manager.get_version(
                agent_id=agent['agent_id'],
                version_id=agent['current_version_id'],
                user_id=user_id
            )
        except Exception as e:
            # Deliberate best-effort: any failure downgrades to the legacy path.
            self._logger.warning(f"Failed to get version data: {e}")
            return None

    def _extract_config(self, agent: Dict[str, Any], version_data: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Collect the prompt, tool and MCP settings used to build a template.

        Sources, in priority order:
        1. ``version_data``, when it is truthy.
        2. The agent's ``config`` dict, when non-empty. Tools are read from
           ``config['tools']`` (keys ``mcp``, ``custom_mcp`` and ``agentpress``).
        3. The agent's legacy top-level columns.

        Null values are normalised so that the MCP lists are always lists,
        ``system_prompt`` is a string and ``agentpress_tools`` is a dict.

        Returns:
            A dict with keys ``system_prompt``, ``agentpress_tools``,
            ``configured_mcps`` and ``custom_mcps``.
        """
        if version_data:
            self._logger.info("Using version config for template creation")
            # ``or`` (not a .get default) so that stored nulls also become empty.
            configured_mcps = version_data.get('configured_mcps') or []
            custom_mcps = version_data.get('custom_mcps') or []

            self._logger.info(f"Version data - configured_mcps: {len(configured_mcps)}, custom_mcps: {len(custom_mcps)}")

            return {
                'system_prompt': version_data.get('system_prompt') or '',
                'agentpress_tools': version_data.get('agentpress_tools') or {},
                'configured_mcps': configured_mcps,
                'custom_mcps': custom_mcps
            }

        self._logger.info("Using legacy config for template creation")

        config = agent.get('config') or {}
        if config:
            self._logger.info("Found unified config, extracting MCP data from it")
            tools = config.get('tools') or {}
            configured_mcps = tools.get('mcp') or []
            custom_mcps = tools.get('custom_mcp') or []
            system_prompt = config.get('system_prompt', agent.get('system_prompt', ''))
            agentpress_tools = tools.get('agentpress', agent.get('agentpress_tools', {}))
        else:
            self._logger.info("No unified config, using legacy columns")
            configured_mcps = agent.get('configured_mcps') or []
            custom_mcps = agent.get('custom_mcps') or []
            system_prompt = agent.get('system_prompt', '')
            agentpress_tools = agent.get('agentpress_tools', {})

        self._logger.info(f"Extracted data - configured_mcps: {len(configured_mcps)}, custom_mcps: {len(custom_mcps)}")
        self._logger.debug(f"Agent keys: {list(agent.keys())}")

        return {
            'system_prompt': system_prompt or '',
            'agentpress_tools': agentpress_tools or {},
            'configured_mcps': configured_mcps,
            'custom_mcps': custom_mcps
        }

    def _build_requirements(
        self,
        configured_mcps: List[ConfigType],
        custom_mcps: List[ConfigType]
    ) -> List[MCPRequirementValue]:
        """Convert configured and custom MCP entries into requirement values.

        Configured entries must be dicts with a ``qualifiedName`` or
        ``qualified_name`` key. Custom entries must be dicts with a ``name`` key.
        Entries that fail these checks, or that make the factory raise, are
        logged and skipped, so the result may be shorter than the input.
        ``None`` for either argument is treated as an empty list.

        Returns:
            Requirements for the configured entries, followed by those for
            the custom entries.
        """
        configured_mcps = configured_mcps or []
        custom_mcps = custom_mcps or []
        requirements = []

        self._logger.info(f"Building requirements from {len(configured_mcps)} configured MCPs and {len(custom_mcps)} custom MCPs")

        # NOTE(review): the debug lines below log whole MCP dicts. If those
        # dicts can carry credentials in their config, consider redacting.
        for i, mcp in enumerate(configured_mcps):
            self._logger.debug(f"Processing configured MCP {i}: {mcp}")

            if not isinstance(mcp, dict):
                self._logger.warning(f"Skipping configured MCP {i}: not a dict")
                continue

            if 'qualifiedName' not in mcp and 'qualified_name' not in mcp:
                self._logger.warning(f"Skipping configured MCP {i}: missing qualifiedName field")
                continue

            try:
                req = self._factory.from_configured_mcp(mcp)
            except Exception as e:
                # Best-effort: one bad entry must not abort template creation.
                self._logger.error(f"Error processing configured MCP {i}: {e}")
                self._logger.debug(f"Failed MCP data: {mcp}")
                continue
            requirements.append(req)
            self._logger.info(f"Added configured requirement: {req.qualified_name}")

        for i, custom_mcp in enumerate(custom_mcps):
            self._logger.debug(f"Processing custom MCP {i}: {custom_mcp}")

            if not (isinstance(custom_mcp, dict) and 'name' in custom_mcp):
                self._logger.warning(f"Skipping custom MCP {i}: missing name field or not dict")
                continue

            try:
                req = self._factory.from_custom_mcp(custom_mcp)
            except Exception as e:
                self._logger.error(f"Error processing custom MCP {i}: {e}")
                self._logger.debug(f"Failed custom MCP data: {custom_mcp}")
                continue
            requirements.append(req)
            self._logger.info(f"Added custom requirement: {req.qualified_name}")

        self._logger.info(f"Built {len(requirements)} total MCP requirements")
        return requirements