# suna/backend/templates/template_service.py
#
# Listing metadata from the source viewer: 364 lines, 14 KiB, Python.

import json
from dataclasses import dataclass, field, replace
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional
from uuid import uuid4

from services.supabase import DBConnection
from utils.logger import logger
# Type aliases used in annotations throughout this module.
# ConfigType: a free-form, string-keyed mapping (agent config / metadata).
ConfigType = Dict[str, Any]
# ProfileId and QualifiedName are plain strings; the aliases only add
# readability to signatures.
ProfileId = str
QualifiedName = str
@dataclass(frozen=True)
class MCPRequirementValue:
    """Immutable description of one MCP server a template depends on."""

    # Identifier of the MCP server, e.g. "pipedream:<slug>" or "custom_<type>_<name>".
    qualified_name: str
    # Human-readable name of the MCP server.
    display_name: str
    # Names of the tools enabled for this MCP server.
    enabled_tools: List[str] = field(default_factory=list)
    # Names of the configuration keys that have to be supplied for this server.
    required_config: List[str] = field(default_factory=list)
    # Transport/type of a custom MCP entry; None for non-custom entries.
    custom_type: Optional[str] = None

    def is_custom(self) -> bool:
        """Return True only for custom (non-Pipedream) MCP requirements.

        A requirement counts as custom when it carries a ``custom_type`` and
        its qualified name starts with ``custom_``; names starting with
        ``pipedream:`` are never custom.
        """
        name = self.qualified_name
        if name.startswith('pipedream:'):
            return False
        if self.custom_type is None:
            return False
        return name.startswith('custom_')
@dataclass(frozen=True)
class AgentTemplate:
    """Immutable agent template record plus read-only views over its config.

    The ``system_prompt``, ``agentpress_tools`` and ``mcp_requirements``
    properties are derived from ``config`` on every access; nothing is cached.
    """

    template_id: str
    creator_id: str
    name: str
    # Agent configuration dict; read by the properties below.
    config: 'ConfigType'
    description: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    is_public: bool = False
    # Set together with ``is_public``; None while the template is private.
    marketplace_published_at: Optional[datetime] = None
    download_count: int = 0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    avatar: Optional[str] = None
    avatar_color: Optional[str] = None
    metadata: 'ConfigType' = field(default_factory=dict)

    def with_public_status(self, is_public: bool, published_at: Optional[datetime] = None) -> 'AgentTemplate':
        """Return a copy with ``is_public`` and ``marketplace_published_at`` replaced.

        All other fields are carried over unchanged (shallow copy); the
        instance itself is frozen and is not modified.
        """
        return replace(
            self,
            is_public=is_public,
            marketplace_published_at=published_at,
        )

    def _tools_section(self) -> Dict[str, Any]:
        """Return ``config['tools']``, treating a missing or null value as empty."""
        return (self.config or {}).get('tools') or {}

    @property
    def system_prompt(self) -> str:
        """System prompt stored in the config, or ``''`` when absent or null."""
        return (self.config or {}).get('system_prompt') or ''

    @property
    def agentpress_tools(self) -> Dict[str, Any]:
        """The ``tools.agentpress`` mapping, or ``{}`` when absent or null."""
        return self._tools_section().get('agentpress') or {}

    @property
    def mcp_requirements(self) -> List['MCPRequirementValue']:
        """Build the list of MCP requirements declared in the config.

        Entries of ``tools.mcp`` come first, followed by ``tools.custom_mcp``.
        Entries that are not dicts or have no ``name`` are skipped. Missing or
        null sections/values are treated as empty rather than raising.
        """
        tools = self._tools_section()
        requirements = []

        for mcp in tools.get('mcp') or []:
            if not (isinstance(mcp, dict) and mcp.get('name')):
                continue
            requirements.append(MCPRequirementValue(
                qualified_name=mcp.get('qualifiedName', mcp['name']),
                display_name=mcp.get('display_name') or mcp['name'],
                enabled_tools=mcp.get('enabledTools') or [],
                required_config=mcp.get('requiredConfig') or [],
            ))

        for mcp in tools.get('custom_mcp') or []:
            if not (isinstance(mcp, dict) and mcp.get('name')):
                continue
            mcp_type = mcp.get('type', 'sse')
            mcp_name = mcp['name']
            if mcp_type == 'pipedream':
                headers = (mcp.get('config') or {}).get('headers') or {}
                app_slug = headers.get('x-pd-app-slug')
                if not app_slug:
                    # No explicit slug header: derive one from the name by
                    # lowercasing and dropping spaces and parentheses.
                    app_slug = mcp_name.lower().replace(' ', '').replace('(', '').replace(')', '')
                qualified_name = f"pipedream:{app_slug}"
                required_config = []
            else:
                safe_name = mcp_name.replace(' ', '_').lower()
                qualified_name = f"custom_{mcp_type}_{safe_name}"
                if mcp_type in ('http', 'sse', 'json'):
                    required_config = ['url']
                else:
                    # An explicit (even empty) list is honoured; only a
                    # missing or null value falls back to ['url'].
                    required_config = mcp.get('requiredConfig')
                    if required_config is None:
                        required_config = ['url']
            requirements.append(MCPRequirementValue(
                qualified_name=qualified_name,
                display_name=mcp.get('display_name') or mcp_name,
                enabled_tools=mcp.get('enabledTools') or [],
                required_config=required_config,
                custom_type=mcp_type,
            ))

        return requirements
@dataclass
class TemplateCreationRequest:
    """Parameters for turning an existing agent into a template."""

    # Agent to snapshot and the account requesting the template.
    agent_id: str
    creator_id: str
    # Whether the new template should be public; private by default.
    make_public: bool = field(default=False)
    # Optional tags for the template; None when the caller supplies none.
    tags: Optional[List[str]] = field(default=None)
class TemplateNotFoundError(Exception):
    """Raised when the source agent or its version configuration cannot be found."""
class TemplateAccessDeniedError(Exception):
    """Raised when a user acts on an agent or template they are not allowed to use."""
class SunaDefaultAgentTemplateError(Exception):
    """Raised when a template is requested from the Suna default agent."""
class TemplateService:
    """Create, read and publish agent templates stored in the database.

    Reads and writes the ``agent_templates`` table, reads ``agents`` and
    ``agent_versions``, and calls the ``sanitize_config_for_template`` and
    ``increment_template_download_count`` RPC functions through the injected
    connection's client.
    """

    def __init__(self, db_connection: 'DBConnection'):
        self._db = db_connection

    async def create_from_agent(
        self,
        agent_id: str,
        creator_id: str,
        make_public: bool = False,
        tags: Optional[List[str]] = None
    ) -> str:
        """Snapshot an agent's current version into a new template.

        Args:
            agent_id: Agent to create the template from.
            creator_id: Account requesting the template; must own the agent.
            make_public: Publish the template immediately when True.
            tags: Optional tags stored on the template.

        Returns:
            The id of the newly inserted template.

        Raises:
            TemplateNotFoundError: The agent does not exist or has no
                version configuration.
            TemplateAccessDeniedError: The agent belongs to another account.
            SunaDefaultAgentTemplateError: The agent is the Suna default agent.
        """
        logger.info(f"Creating template from agent {agent_id} for user {creator_id}")

        agent = await self._get_agent_by_id(agent_id)
        if not agent:
            raise TemplateNotFoundError("Agent not found")
        if agent['account_id'] != creator_id:
            raise TemplateAccessDeniedError("You can only create templates from your own agents")
        if self._is_suna_default_agent(agent):
            raise SunaDefaultAgentTemplateError("Cannot create template from Suna default agent")

        version_config = await self._get_agent_version_config(agent)
        if not version_config:
            raise TemplateNotFoundError("Agent has no version configuration")

        sanitized_config = await self._sanitize_config_for_template(version_config)

        template = AgentTemplate(
            template_id=str(uuid4()),
            creator_id=creator_id,
            name=agent['name'],
            description=agent.get('description'),
            config=sanitized_config,
            tags=tags or [],
            is_public=make_public,
            marketplace_published_at=datetime.now(timezone.utc) if make_public else None,
            avatar=agent.get('avatar'),
            avatar_color=agent.get('avatar_color'),
            # A null metadata column must not end up as None on the template.
            metadata=agent.get('metadata') or {},
        )
        await self._save_template(template)

        logger.info(f"Created template {template.template_id} from agent {agent_id}")
        return template.template_id

    async def get_template(self, template_id: str) -> Optional['AgentTemplate']:
        """Return the template with ``template_id``, or None when there is none."""
        client = await self._db.client
        result = await (
            client.table('agent_templates')
            .select('*')
            .eq('template_id', template_id)
            .maybe_single()
            .execute()
        )
        # NOTE(review): some postgrest-py versions resolve maybe_single() to
        # None (not a response object) when no row matches - guard both cases.
        if not result or not result.data:
            return None
        return self._map_to_template(result.data)

    async def get_user_templates(self, creator_id: str) -> List['AgentTemplate']:
        """Return all templates created by ``creator_id``, newest first."""
        client = await self._db.client
        result = await (
            client.table('agent_templates')
            .select('*')
            .eq('creator_id', creator_id)
            .order('created_at', desc=True)
            .execute()
        )
        return [self._map_to_template(row) for row in result.data or []]

    async def get_public_templates(self) -> List['AgentTemplate']:
        """Return public templates ordered by download count, then publish date (both descending)."""
        client = await self._db.client
        result = await (
            client.table('agent_templates')
            .select('*')
            .eq('is_public', True)
            .order('download_count', desc=True)
            .order('marketplace_published_at', desc=True)
            .execute()
        )
        return [self._map_to_template(row) for row in result.data or []]

    async def publish_template(self, template_id: str, creator_id: str) -> bool:
        """Mark a template public; returns True when a row owned by ``creator_id`` was updated."""
        logger.info(f"Publishing template {template_id}")
        client = await self._db.client
        # One timestamp so published_at and updated_at agree exactly.
        now_iso = datetime.now(timezone.utc).isoformat()
        result = await (
            client.table('agent_templates')
            .update({
                'is_public': True,
                'marketplace_published_at': now_iso,
                'updated_at': now_iso,
            })
            .eq('template_id', template_id)
            .eq('creator_id', creator_id)
            .execute()
        )
        success = bool(result.data)
        if success:
            logger.info(f"Published template {template_id}")
        return success

    async def unpublish_template(self, template_id: str, creator_id: str) -> bool:
        """Mark a template private; returns True when a row owned by ``creator_id`` was updated."""
        logger.info(f"Unpublishing template {template_id}")
        client = await self._db.client
        result = await (
            client.table('agent_templates')
            .update({
                'is_public': False,
                'marketplace_published_at': None,
                'updated_at': datetime.now(timezone.utc).isoformat(),
            })
            .eq('template_id', template_id)
            .eq('creator_id', creator_id)
            .execute()
        )
        success = bool(result.data)
        if success:
            logger.info(f"Unpublished template {template_id}")
        return success

    async def increment_download_count(self, template_id: str) -> None:
        """Call the ``increment_template_download_count`` RPC for ``template_id``."""
        client = await self._db.client
        await client.rpc('increment_template_download_count', {
            'template_id_param': template_id
        }).execute()

    async def validate_access(self, template: 'AgentTemplate', user_id: str) -> None:
        """Raise TemplateAccessDeniedError unless the template is public or owned by ``user_id``."""
        if template.creator_id != user_id and not template.is_public:
            raise TemplateAccessDeniedError("Access denied to template")

    async def _get_agent_by_id(self, agent_id: str) -> Optional[Dict[str, Any]]:
        """Fetch the ``agents`` row for ``agent_id``; None when it does not exist."""
        client = await self._db.client
        result = await (
            client.table('agents')
            .select('*')
            .eq('agent_id', agent_id)
            .maybe_single()
            .execute()
        )
        # See get_template: the awaited result itself may be None.
        return result.data if result else None

    async def _get_agent_version_config(self, agent: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the config of the agent's current version, or ``{}`` when unavailable."""
        version_id = agent.get('current_version_id')
        if version_id:
            client = await self._db.client
            result = await (
                client.table('agent_versions')
                .select('config')
                .eq('version_id', version_id)
                .maybe_single()
                .execute()
            )
            version_config = result.data.get('config') if result and result.data else None
            if version_config:
                return version_config
        return {}

    async def _sanitize_config_for_template(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize ``config`` via the database RPC, using the local fallback on an empty result."""
        client = await self._db.client
        result = await client.rpc('sanitize_config_for_template', {
            'input_config': config
        }).execute()
        if result.data:
            return result.data
        return self._fallback_sanitize_config(config)

    def _fallback_sanitize_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build a reduced copy of ``config`` that is safe to store on a template.

        Keeps the system prompt, the agentpress and mcp tool sections and the
        avatar metadata. Custom MCP entries keep only name, type, display
        name and enabled tools; their ``config`` is emptied, except for
        ``pipedream`` entries, which keep ``url`` and all headers other than
        ``profile_id``. Missing or null sections are treated as empty.
        """
        tools = config.get('tools') or {}
        metadata = config.get('metadata') or {}

        sanitized = {
            'system_prompt': config.get('system_prompt', ''),
            'tools': {
                'agentpress': tools.get('agentpress') or {},
                'mcp': tools.get('mcp') or [],
                'custom_mcp': [],
            },
            'metadata': {
                'avatar': metadata.get('avatar'),
                'avatar_color': metadata.get('avatar_color'),
            },
        }

        for mcp in tools.get('custom_mcp') or []:
            if not isinstance(mcp, dict):
                continue
            sanitized_mcp = {
                'name': mcp.get('name'),
                'type': mcp.get('type'),
                'display_name': mcp.get('display_name') or mcp.get('name'),
                'enabledTools': mcp.get('enabledTools', []),
            }
            if mcp.get('type') == 'pipedream':
                original_config = mcp.get('config') or {}
                headers = original_config.get('headers') or {}
                sanitized_mcp['config'] = {
                    'url': original_config.get('url'),
                    # profile_id is the only header removed here.
                    'headers': {k: v for k, v in headers.items() if k != 'profile_id'},
                }
            else:
                sanitized_mcp['config'] = {}
            sanitized['tools']['custom_mcp'].append(sanitized_mcp)

        return sanitized

    def _is_suna_default_agent(self, agent: Dict[str, Any]) -> bool:
        """Return True when the agent's metadata has a truthy ``is_suna_default`` flag."""
        # `or {}` also covers a metadata key that is present but null.
        metadata = agent.get('metadata') or {}
        return bool(metadata.get('is_suna_default', False))

    async def _save_template(self, template: 'AgentTemplate') -> None:
        """Insert ``template`` as a new row in ``agent_templates``."""
        client = await self._db.client
        published_at = template.marketplace_published_at
        template_data = {
            'template_id': template.template_id,
            'creator_id': template.creator_id,
            'name': template.name,
            'description': template.description,
            'config': template.config,
            'tags': template.tags,
            'is_public': template.is_public,
            'marketplace_published_at': published_at.isoformat() if published_at else None,
            'download_count': template.download_count,
            'created_at': template.created_at.isoformat(),
            'updated_at': template.updated_at.isoformat(),
            'avatar': template.avatar,
            'avatar_color': template.avatar_color,
            'metadata': template.metadata,
        }
        await client.table('agent_templates').insert(template_data).execute()

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse an ISO-8601 string, accepting a trailing ``Z`` as UTC."""
        return datetime.fromisoformat(value.replace('Z', '+00:00'))

    def _map_to_template(self, data: Dict[str, Any]) -> 'AgentTemplate':
        """Convert an ``agent_templates`` row dict into an AgentTemplate.

        Null ``config``, ``tags``, ``metadata`` and ``download_count`` values
        are replaced by their empty defaults.
        """
        published_raw = data.get('marketplace_published_at')
        return AgentTemplate(
            template_id=data['template_id'],
            creator_id=data['creator_id'],
            name=data['name'],
            description=data.get('description'),
            config=data.get('config') or {},
            tags=data.get('tags') or [],
            is_public=data.get('is_public', False),
            marketplace_published_at=self._parse_timestamp(published_raw) if published_raw else None,
            download_count=data.get('download_count') or 0,
            created_at=self._parse_timestamp(data['created_at']),
            updated_at=self._parse_timestamp(data['updated_at']),
            avatar=data.get('avatar'),
            avatar_color=data.get('avatar_color'),
            metadata=data.get('metadata') or {},
        )
def get_template_service(db_connection: DBConnection) -> TemplateService:
    """Build a TemplateService bound to the given database connection."""
    service = TemplateService(db_connection)
    return service