import json
import re
from dataclasses import dataclass, field, replace
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional
from uuid import uuid4

from services.supabase import DBConnection
from utils.logger import logger

ConfigType = Dict[str, Any]
ProfileId = str
QualifiedName = str

# Fractional-seconds component of an ISO-8601 timestamp (e.g. ".12345").
_FRACTION_RE = re.compile(r'\.(\d+)')


def _parse_timestamp(value: str) -> datetime:
    """Parse an ISO-8601 timestamp string read back from the database.

    Before Python 3.11, ``datetime.fromisoformat`` rejects a trailing ``Z`` and
    only accepts fractional seconds with exactly 3 or 6 digits, while Postgres
    trims trailing zeros (``.12345``). Both forms are normalised here so that
    parsing behaves the same on every interpreter version.
    """
    text = value.replace('Z', '+00:00')
    # Pad (or truncate) the fraction to exactly six digits (microseconds).
    text = _FRACTION_RE.sub(lambda m: '.' + m.group(1)[:6].ljust(6, '0'), text, count=1)
    return datetime.fromisoformat(text)


def _response_data(response: Any) -> Any:
    """Return ``response.data``, or ``None`` when the response itself is ``None``.

    Some postgrest-py releases return ``None`` from ``maybe_single().execute()``
    when no row matches, instead of a response whose ``data`` is ``None``.
    """
    if response is None:
        return None
    return response.data


@dataclass(frozen=True)
class MCPRequirementValue:
    """One MCP server a template needs, plus the config keys it requires."""

    qualified_name: str
    display_name: str
    enabled_tools: List[str] = field(default_factory=list)
    required_config: List[str] = field(default_factory=list)
    custom_type: Optional[str] = None

    def is_custom(self) -> bool:
        """Return True for non-Pipedream custom MCPs (``custom_*`` names with a type)."""
        if self.qualified_name.startswith('pipedream:'):
            return False
        return self.custom_type is not None and self.qualified_name.startswith('custom_')


@dataclass(frozen=True)
class AgentTemplate:
    """Immutable snapshot of one ``agent_templates`` row.

    ``config`` holds the sanitized agent configuration; the properties below
    expose the parts of it that template consumers read.
    """

    template_id: str
    creator_id: str
    name: str
    config: ConfigType
    description: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    is_public: bool = False
    marketplace_published_at: Optional[datetime] = None
    download_count: int = 0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    avatar: Optional[str] = None
    avatar_color: Optional[str] = None
    metadata: ConfigType = field(default_factory=dict)

    def with_public_status(self, is_public: bool, published_at: Optional[datetime] = None) -> 'AgentTemplate':
        """Return a copy with ``is_public`` and ``marketplace_published_at`` replaced."""
        return replace(self, is_public=is_public, marketplace_published_at=published_at)

    def _tools(self) -> Dict[str, Any]:
        # ``tools`` may be missing or explicitly null in stored configs.
        return self.config.get('tools') or {}

    @property
    def system_prompt(self) -> str:
        """The template's system prompt ('' when unset)."""
        return self.config.get('system_prompt') or ''

    @property
    def agentpress_tools(self) -> Dict[str, Any]:
        """The ``tools.agentpress`` section of the config ({} when unset)."""
        return self._tools().get('agentpress') or {}

    @property
    def mcp_requirements(self) -> List[MCPRequirementValue]:
        """List the MCP servers declared by this template.

        Entries come from ``tools.mcp`` (qualified name taken from
        ``qualifiedName``, falling back to ``name``) and ``tools.custom_mcp``
        (qualified name synthesised as ``pipedream:<slug>`` or
        ``custom_<type>_<name>``). Entries that are not dicts or lack a
        ``name`` are skipped.
        """
        tools = self._tools()
        requirements: List[MCPRequirementValue] = []

        for mcp in tools.get('mcp') or []:
            if isinstance(mcp, dict) and mcp.get('name'):
                requirements.append(MCPRequirementValue(
                    qualified_name=mcp.get('qualifiedName') or mcp['name'],
                    display_name=mcp.get('display_name') or mcp['name'],
                    enabled_tools=mcp.get('enabledTools') or [],
                    required_config=mcp.get('requiredConfig') or [],
                ))

        for mcp in tools.get('custom_mcp') or []:
            if isinstance(mcp, dict) and mcp.get('name'):
                requirements.append(self._custom_mcp_requirement(mcp))

        return requirements

    @staticmethod
    def _custom_mcp_requirement(mcp: Dict[str, Any]) -> MCPRequirementValue:
        """Build the requirement for a single ``tools.custom_mcp`` entry."""
        mcp_type = mcp.get('type') or 'sse'
        mcp_name = mcp['name']

        if mcp_type == 'pipedream':
            headers = (mcp.get('config') or {}).get('headers') or {}
            app_slug = headers.get('x-pd-app-slug')
            if not app_slug:
                # Derive a slug from the name, e.g. "Google Sheets (v2)" -> "googlesheetsv2".
                app_slug = mcp_name.lower().replace(' ', '').replace('(', '').replace(')', '')
            qualified_name = f"pipedream:{app_slug}"
            # Pipedream entries declare no required config keys.
            required_config: List[str] = []
        else:
            safe_name = mcp_name.replace(' ', '_').lower()
            qualified_name = f"custom_{mcp_type}_{safe_name}"
            if mcp_type in ('http', 'sse', 'json'):
                required_config = ['url']
            else:
                required_config = mcp.get('requiredConfig', ['url'])

        return MCPRequirementValue(
            qualified_name=qualified_name,
            display_name=mcp.get('display_name') or mcp_name,
            enabled_tools=mcp.get('enabledTools') or [],
            required_config=required_config,
            custom_type=mcp_type,
        )


@dataclass
class TemplateCreationRequest:
    """Parameters for creating a template from an existing agent."""

    agent_id: str
    creator_id: str
    make_public: bool = False
    tags: Optional[List[str]] = None


class TemplateNotFoundError(Exception):
    """Raised when the source agent, or its version configuration, is missing."""


class TemplateAccessDeniedError(Exception):
    """Raised when a user may not use or derive from a template/agent."""


class SunaDefaultAgentTemplateError(Exception):
    """Raised when attempting to template the Suna default agent."""


class TemplateService:
    """Create, read, publish and unpublish rows in the ``agent_templates`` table."""

    def __init__(self, db_connection: DBConnection):
        self._db = db_connection

    async def create_from_agent(
        self,
        agent_id: str,
        creator_id: str,
        make_public: bool = False,
        tags: Optional[List[str]] = None
    ) -> str:
        """Snapshot an agent's current version config into a new template.

        Args:
            agent_id: Agent to copy from.
            creator_id: Account creating the template; must own the agent.
            make_public: Publish to the marketplace immediately.
            tags: Optional tags stored with the template.

        Returns:
            The new template's id.

        Raises:
            TemplateNotFoundError: The agent or its version config is missing.
            TemplateAccessDeniedError: ``creator_id`` does not own the agent.
            SunaDefaultAgentTemplateError: The agent is the Suna default agent.
        """
        logger.info(f"Creating template from agent {agent_id} for user {creator_id}")

        agent = await self._get_agent_by_id(agent_id)
        if not agent:
            raise TemplateNotFoundError("Agent not found")
        if agent['account_id'] != creator_id:
            raise TemplateAccessDeniedError("You can only create templates from your own agents")
        if self._is_suna_default_agent(agent):
            raise SunaDefaultAgentTemplateError("Cannot create template from Suna default agent")

        version_config = await self._get_agent_version_config(agent)
        if not version_config:
            raise TemplateNotFoundError("Agent has no version configuration")

        sanitized_config = await self._sanitize_config_for_template(version_config)

        template = AgentTemplate(
            template_id=str(uuid4()),
            creator_id=creator_id,
            name=agent['name'],
            description=agent.get('description'),
            config=sanitized_config,
            tags=tags or [],
            is_public=make_public,
            marketplace_published_at=datetime.now(timezone.utc) if make_public else None,
            avatar=agent.get('avatar'),
            avatar_color=agent.get('avatar_color'),
            # ``metadata`` may be stored as null.
            metadata=agent.get('metadata') or {},
        )

        await self._save_template(template)
        logger.info(f"Created template {template.template_id} from agent {agent_id}")
        return template.template_id

    async def get_template(self, template_id: str) -> Optional[AgentTemplate]:
        """Return the template with ``template_id``, or None if it does not exist."""
        client = await self._db.client
        result = await (
            client.table('agent_templates').select('*')
            .eq('template_id', template_id)
            .maybe_single()
            .execute()
        )
        data = _response_data(result)
        if not data:
            return None
        return self._map_to_template(data)

    async def get_user_templates(self, creator_id: str) -> List[AgentTemplate]:
        """Return all templates created by ``creator_id``, newest first."""
        client = await self._db.client
        result = await (
            client.table('agent_templates').select('*')
            .eq('creator_id', creator_id)
            .order('created_at', desc=True)
            .execute()
        )
        return [self._map_to_template(row) for row in _response_data(result) or []]

    async def get_public_templates(self) -> List[AgentTemplate]:
        """Return public templates, most downloaded first, then most recently published."""
        client = await self._db.client
        result = await (
            client.table('agent_templates').select('*')
            .eq('is_public', True)
            .order('download_count', desc=True)
            .order('marketplace_published_at', desc=True)
            .execute()
        )
        return [self._map_to_template(row) for row in _response_data(result) or []]

    async def publish_template(self, template_id: str, creator_id: str) -> bool:
        """Mark a template public; returns True if a row owned by ``creator_id`` was updated."""
        logger.info(f"Publishing template {template_id}")

        client = await self._db.client
        # One timestamp so published_at and updated_at agree exactly.
        now = datetime.now(timezone.utc).isoformat()
        result = await (
            client.table('agent_templates').update({
                'is_public': True,
                'marketplace_published_at': now,
                'updated_at': now,
            })
            .eq('template_id', template_id)
            .eq('creator_id', creator_id)
            .execute()
        )

        success = bool(_response_data(result))
        if success:
            logger.info(f"Published template {template_id}")
        return success

    async def unpublish_template(self, template_id: str, creator_id: str) -> bool:
        """Mark a template private; returns True if a row owned by ``creator_id`` was updated."""
        logger.info(f"Unpublishing template {template_id}")

        client = await self._db.client
        result = await (
            client.table('agent_templates').update({
                'is_public': False,
                'marketplace_published_at': None,
                'updated_at': datetime.now(timezone.utc).isoformat(),
            })
            .eq('template_id', template_id)
            .eq('creator_id', creator_id)
            .execute()
        )

        success = bool(_response_data(result))
        if success:
            logger.info(f"Unpublished template {template_id}")
        return success

    async def increment_download_count(self, template_id: str) -> None:
        """Call the ``increment_template_download_count`` database function."""
        client = await self._db.client
        await client.rpc('increment_template_download_count', {
            'template_id_param': template_id
        }).execute()

    async def validate_access(self, template: AgentTemplate, user_id: str) -> None:
        """Raise TemplateAccessDeniedError unless the user created it or it is public."""
        if template.creator_id != user_id and not template.is_public:
            raise TemplateAccessDeniedError("Access denied to template")

    async def _get_agent_by_id(self, agent_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``agents`` row for ``agent_id``, or None."""
        client = await self._db.client
        result = await (
            client.table('agents').select('*')
            .eq('agent_id', agent_id)
            .maybe_single()
            .execute()
        )
        return _response_data(result)

    async def _get_agent_version_config(self, agent: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the config of the agent's current version, or {} when unavailable."""
        version_id = agent.get('current_version_id')
        if not version_id:
            return {}

        client = await self._db.client
        result = await (
            client.table('agent_versions').select('config')
            .eq('version_id', version_id)
            .maybe_single()
            .execute()
        )
        data = _response_data(result)
        if data and data.get('config'):
            return data['config']
        return {}

    async def _sanitize_config_for_template(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize via the ``sanitize_config_for_template`` RPC, falling back locally.

        The local fallback is used only when the RPC returns no data.
        """
        client = await self._db.client
        result = await client.rpc('sanitize_config_for_template', {
            'input_config': config
        }).execute()

        data = _response_data(result)
        if data:
            return data
        return self._fallback_sanitize_config(config)

    def _fallback_sanitize_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build a template-safe copy of ``config`` without calling the database.

        Keeps the system prompt, AgentPress tools, standard MCP entries and
        avatar metadata. Custom MCP entries are reduced to name/type/display
        name/enabled tools; Pipedream entries also keep their URL and headers
        minus ``profile_id``.
        """
        tools = config.get('tools') or {}
        metadata = config.get('metadata') or {}

        sanitized = {
            'system_prompt': config.get('system_prompt', ''),
            'tools': {
                'agentpress': tools.get('agentpress', {}),
                # NOTE(review): standard MCP entries are copied verbatim, including any
                # 'config' they carry; confirm these never contain user credentials.
                'mcp': tools.get('mcp', []),
                'custom_mcp': []
            },
            'metadata': {
                'avatar': metadata.get('avatar'),
                'avatar_color': metadata.get('avatar_color')
            }
        }

        for mcp in tools.get('custom_mcp') or []:
            if not isinstance(mcp, dict):
                continue
            sanitized_mcp = {
                'name': mcp.get('name'),
                'type': mcp.get('type'),
                'display_name': mcp.get('display_name') or mcp.get('name'),
                'enabledTools': mcp.get('enabledTools', [])
            }
            if mcp.get('type') == 'pipedream':
                original_config = mcp.get('config') or {}
                headers = original_config.get('headers') or {}
                sanitized_mcp['config'] = {
                    'url': original_config.get('url'),
                    # Remove 'profile_id'; other headers (e.g. 'x-pd-app-slug',
                    # read by AgentTemplate.mcp_requirements) are kept.
                    'headers': {k: v for k, v in headers.items() if k != 'profile_id'}
                }
            else:
                # Strip config entirely so no connection details are copied into the template.
                sanitized_mcp['config'] = {}
            sanitized['tools']['custom_mcp'].append(sanitized_mcp)

        return sanitized

    def _is_suna_default_agent(self, agent: Dict[str, Any]) -> bool:
        """True when the agent's metadata has a truthy ``is_suna_default`` flag."""
        metadata = agent.get('metadata') or {}
        return bool(metadata.get('is_suna_default', False))

    async def _save_template(self, template: AgentTemplate) -> None:
        """Insert ``template`` as a new ``agent_templates`` row (datetimes as ISO strings)."""
        client = await self._db.client
        template_data = {
            'template_id': template.template_id,
            'creator_id': template.creator_id,
            'name': template.name,
            'description': template.description,
            'config': template.config,
            'tags': template.tags,
            'is_public': template.is_public,
            'marketplace_published_at': (
                template.marketplace_published_at.isoformat()
                if template.marketplace_published_at else None
            ),
            'download_count': template.download_count,
            'created_at': template.created_at.isoformat(),
            'updated_at': template.updated_at.isoformat(),
            'avatar': template.avatar,
            'avatar_color': template.avatar_color,
            'metadata': template.metadata
        }
        await client.table('agent_templates').insert(template_data).execute()

    def _map_to_template(self, data: Dict[str, Any]) -> AgentTemplate:
        """Convert an ``agent_templates`` row into an AgentTemplate.

        Null JSON/array columns become empty containers so the dataclass's
        declared types hold.
        """
        published_at = data.get('marketplace_published_at')
        return AgentTemplate(
            template_id=data['template_id'],
            creator_id=data['creator_id'],
            name=data['name'],
            description=data.get('description'),
            config=data.get('config') or {},
            tags=data.get('tags') or [],
            is_public=bool(data.get('is_public', False)),
            marketplace_published_at=_parse_timestamp(published_at) if published_at else None,
            download_count=data.get('download_count') or 0,
            created_at=_parse_timestamp(data['created_at']),
            updated_at=_parse_timestamp(data['updated_at']),
            avatar=data.get('avatar'),
            avatar_color=data.get('avatar_color'),
            metadata=data.get('metadata') or {}
        )


def get_template_service(db_connection: DBConnection) -> TemplateService:
    """Construct a TemplateService bound to ``db_connection``."""
    return TemplateService(db_connection)