# Mirror of https://github.com/kortix-ai/suna.git (Python, 768 lines, 31 KiB)
import json
import re
import secrets
import string
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional
from uuid import uuid4

from services.supabase import DBConnection
from utils.logger import logger


# JSON-like configuration mapping (used for template config and metadata).
ConfigType = Dict[str, Any]

# Plain-string identifiers; these aliases only document intent.
ProfileId = QualifiedName = str


@dataclass(frozen=True)
class MCPRequirementValue:
    """One integration a template needs before it can be installed.

    Built by ``AgentTemplate.mcp_requirements``: one entry per MCP server in
    the template config and one per Composio-backed trigger.
    """

    qualified_name: str  # e.g. 'pipedream:<slug>', 'composio.<slug>', 'custom_<type>_<name>'
    display_name: str  # human-readable label
    enabled_tools: List[str] = field(default_factory=list)
    required_config: List[str] = field(default_factory=list)  # config keys the installer must supply
    custom_type: Optional[str] = None  # the custom MCP 'type' value, if any
    toolkit_slug: Optional[str] = None  # Composio toolkit slug, when known
    app_slug: Optional[str] = None  # Pipedream app slug (or Composio app for triggers)
    source: Optional[str] = None  # 'tool' or 'trigger' when built by AgentTemplate
    trigger_index: Optional[int] = None  # index into config['triggers'] for trigger requirements

    def is_custom(self) -> bool:
        """Return True when this requirement is a user-defined custom MCP server.

        Names prefixed ``'pipedream:'`` or ``'composio.'`` and anything whose
        ``custom_type`` is ``'composio'`` are excluded; otherwise a
        ``custom_type`` must be set and the name must start with ``'custom_'``.
        """
        name = self.qualified_name
        if name.startswith(('pipedream:', 'composio.')) or self.custom_type == 'composio':
            return False
        return self.custom_type is not None and name.startswith('custom_')


@dataclass(frozen=True)
class AgentTemplate:
    """Immutable snapshot of a row in the ``agent_templates`` table.

    ``config`` holds the template's agent configuration. The properties below
    are read-only views into it and tolerate missing or null sections.
    """

    template_id: str
    creator_id: str
    name: str
    config: ConfigType
    description: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    is_public: bool = False
    is_kortix_team: bool = False
    marketplace_published_at: Optional[datetime] = None
    download_count: int = 0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    avatar: Optional[str] = None
    avatar_color: Optional[str] = None
    profile_image_url: Optional[str] = None
    icon_name: Optional[str] = None
    icon_color: Optional[str] = None
    icon_background: Optional[str] = None
    metadata: ConfigType = field(default_factory=dict)
    creator_name: Optional[str] = None

    def with_public_status(self, is_public: bool, published_at: Optional[datetime] = None) -> 'AgentTemplate':
        """Return a copy with the marketplace visibility fields replaced.

        Args:
            is_public: New value for ``is_public``.
            published_at: New ``marketplace_published_at`` (None clears it).
        """
        # dataclasses.replace() rebuilds through __init__ with every field,
        # unlike re-splatting __dict__, which breaks with __slots__ or
        # init=False fields.
        from dataclasses import replace
        return replace(self, is_public=is_public, marketplace_published_at=published_at)

    def _config(self) -> ConfigType:
        # A stored config may be null; treat that as an empty config.
        return self.config or {}

    def _tools(self) -> Dict[str, Any]:
        # The 'tools' section may be missing or explicitly null.
        return self._config().get('tools') or {}

    @property
    def system_prompt(self) -> str:
        """The configured system prompt, or '' when absent or null."""
        return self._config().get('system_prompt') or ''

    @property
    def agentpress_tools(self) -> Dict[str, Any]:
        """The ``tools.agentpress`` mapping, or {} when absent or null."""
        return self._tools().get('agentpress') or {}

    @property
    def workflows(self) -> List[Dict[str, Any]]:
        """The configured workflows, or [] when absent or null."""
        return self._config().get('workflows') or []

    @property
    def mcp_requirements(self) -> List[MCPRequirementValue]:
        """Collect every integration the template needs before installation.

        Requirements are gathered in this order: ``tools.mcp`` servers,
        ``tools.custom_mcp`` servers, then Composio-provider triggers.
        """
        tools = self._tools()
        requirements: List[MCPRequirementValue] = []
        requirements.extend(self._standard_mcp_requirements(tools.get('mcp') or []))
        requirements.extend(self._custom_mcp_requirements(tools.get('custom_mcp') or []))
        requirements.extend(self._trigger_requirements(self._config().get('triggers') or []))
        return requirements

    @staticmethod
    def _standard_mcp_requirements(mcps: List[Any]) -> List[MCPRequirementValue]:
        """Build requirements for ``tools.mcp`` entries; skips non-dict/unnamed ones."""
        requirements = []
        for mcp in mcps:
            if not (isinstance(mcp, dict) and mcp.get('name')):
                continue
            requirements.append(MCPRequirementValue(
                # `or` rather than a .get() default: an explicit null
                # qualifiedName must still fall back to the name.
                qualified_name=mcp.get('qualifiedName') or mcp['name'],
                display_name=mcp.get('display_name') or mcp['name'],
                # Copy so the frozen value object does not alias config lists.
                enabled_tools=list(mcp.get('enabledTools') or []),
                required_config=list(mcp.get('requiredConfig') or []),
                source='tool',
            ))
        return requirements

    @staticmethod
    def _custom_mcp_requirements(custom_mcps: List[Any]) -> List[MCPRequirementValue]:
        """Build requirements for ``tools.custom_mcp`` entries (custom, Pipedream, Composio)."""
        requirements = []
        for mcp in custom_mcps:
            if not (isinstance(mcp, dict) and mcp.get('name')):
                continue
            mcp_type = mcp.get('type', 'sse')
            mcp_name = mcp['name']

            qualified_name = mcp.get('mcp_qualified_name') or mcp.get('qualifiedName')
            if not qualified_name:
                if mcp_type == 'pipedream':
                    headers = (mcp.get('config') or {}).get('headers') or {}
                    app_slug = mcp.get('app_slug') or headers.get('x-pd-app-slug')
                    if not app_slug:
                        app_slug = mcp_name.lower().replace(' ', '').replace('(', '').replace(')', '')
                    qualified_name = f"pipedream:{app_slug}"
                elif mcp_type == 'composio':
                    toolkit_slug = mcp.get('toolkit_slug') or mcp_name.lower().replace(' ', '_')
                    qualified_name = f"composio.{toolkit_slug}"
                else:
                    safe_name = mcp_name.replace(' ', '_').lower()
                    qualified_name = f"custom_{mcp_type}_{safe_name}"

            # Pipedream/Composio connections are resolved through profiles, so
            # they need no raw config; URL-based transports need a URL.
            if mcp_type in ('pipedream', 'composio'):
                required_config: List[str] = []
            elif mcp_type in ('http', 'sse', 'json'):
                required_config = ['url']
            else:
                configured = mcp.get('requiredConfig')
                required_config = list(configured) if configured is not None else ['url']

            requirements.append(MCPRequirementValue(
                qualified_name=qualified_name,
                display_name=mcp.get('display_name') or mcp_name,
                enabled_tools=list(mcp.get('enabledTools') or []),
                required_config=required_config,
                custom_type=mcp_type,
                toolkit_slug=mcp.get('toolkit_slug') if mcp_type == 'composio' else None,
                app_slug=mcp.get('app_slug') if mcp_type == 'pipedream' else None,
                source='tool',
            ))
        return requirements

    @staticmethod
    def _trigger_requirements(triggers: List[Any]) -> List[MCPRequirementValue]:
        """Build one requirement per trigger whose provider is Composio.

        ``trigger_index`` is the trigger's position in the original list, so
        non-dict entries are skipped without shifting later indices.
        """
        requirements = []
        for index, trigger in enumerate(triggers):
            if not isinstance(trigger, dict):
                continue
            trigger_config = trigger.get('config') or {}
            if trigger_config.get('provider_id', '') != 'composio':
                continue

            qualified_name = trigger_config.get('qualified_name')
            if not qualified_name:
                trigger_slug = trigger_config.get('trigger_slug', '')
                if trigger_slug:
                    # The app name is taken from the slug's first '_'-separated part.
                    app_name = trigger_slug.split('_')[0].lower() if '_' in trigger_slug else 'composio'
                    qualified_name = f'composio.{app_name}'
                else:
                    qualified_name = 'composio'

            if qualified_name.startswith('composio.'):
                app_name = qualified_name.split('.', 1)[1]
            else:
                app_name = 'composio'

            trigger_name = trigger.get('name') or f'Trigger {index + 1}'
            requirements.append(MCPRequirementValue(
                qualified_name=qualified_name,
                display_name=f"{app_name.title()} ({trigger_name})",
                enabled_tools=[],
                required_config=[],
                custom_type=None,
                toolkit_slug=app_name,
                app_slug=app_name,
                source='trigger',
                trigger_index=index,
            ))
        return requirements


@dataclass
class TemplateCreationRequest:
    """Parameters for creating a template from an existing agent.

    Field names mirror the arguments of ``TemplateService.create_from_agent``.
    """

    agent_id: str  # agent to turn into a template
    creator_id: str  # account creating the template
    make_public: bool = False  # publish to the marketplace right away
    tags: Optional[List[str]] = None  # optional marketplace tags


class TemplateNotFoundError(Exception):
    """Raised when an agent, its version configuration, or a template is missing."""


class TemplateAccessDeniedError(Exception):
    """Raised when a user acts on an agent or template they may not access."""


class SunaDefaultAgentTemplateError(Exception):
    """Raised when attempting to create a template from the Suna default agent."""


class TemplateService:
    """Create, query, publish and share agent templates stored in Supabase.

    Each method obtains a client via ``await self._db.client`` and works on the
    ``agent_templates`` and ``template_share_links`` tables; creator display
    names are read from ``basejump.accounts``.
    """

    # Share ids exclude look-alike characters (0/o, 1/l/i).
    _SHARE_ID_ALPHABET = ''.join(
        ch for ch in string.ascii_lowercase + string.digits if ch not in '0o1li'
    )
    _SHARE_ID_LENGTH = 10

    def __init__(self, db_connection: DBConnection):
        self._db = db_connection

    async def create_from_agent(
        self,
        agent_id: str,
        creator_id: str,
        make_public: bool = False,
        tags: Optional[List[str]] = None
    ) -> str:
        """Snapshot an agent's current version into a new template.

        Args:
            agent_id: Agent to copy.
            creator_id: Account creating the template; must own the agent.
            make_public: Publish to the marketplace immediately.
            tags: Optional marketplace tags.

        Returns:
            The new template's id.

        Raises:
            TemplateNotFoundError: The agent or its version config is missing.
            TemplateAccessDeniedError: ``creator_id`` does not own the agent.
            SunaDefaultAgentTemplateError: The agent is the Suna default agent.
        """
        logger.debug(f"Creating template from agent {agent_id} for user {creator_id}")

        agent = await self._get_agent_by_id(agent_id)
        if not agent:
            raise TemplateNotFoundError("Agent not found")

        if agent['account_id'] != creator_id:
            raise TemplateAccessDeniedError("You can only create templates from your own agents")

        if self._is_suna_default_agent(agent):
            raise SunaDefaultAgentTemplateError("Cannot create template from Suna default agent")

        version_config = await self._get_agent_version_config(agent)
        if not version_config:
            raise TemplateNotFoundError("Agent has no version configuration")

        sanitized_config = await self._sanitize_config_for_template(version_config)

        template = AgentTemplate(
            template_id=str(uuid4()),
            creator_id=creator_id,
            name=agent['name'],
            description=agent.get('description'),
            config=sanitized_config,
            tags=tags or [],
            is_public=make_public,
            marketplace_published_at=datetime.now(timezone.utc) if make_public else None,
            avatar=agent.get('avatar'),
            avatar_color=agent.get('avatar_color'),
            profile_image_url=agent.get('profile_image_url'),
            icon_name=agent.get('icon_name'),
            icon_color=agent.get('icon_color'),
            icon_background=agent.get('icon_background'),
            # A null metadata column would otherwise be stored as null.
            metadata=agent.get('metadata') or {}
        )

        await self._save_template(template)

        logger.debug(f"Created template {template.template_id} from agent {agent_id}")
        return template.template_id

    async def get_template(self, template_id: str) -> Optional[AgentTemplate]:
        """Return the template with ``template_id`` (with creator name), or None."""
        client = await self._db.client
        result = await client.table('agent_templates').select('*')\
            .eq('template_id', template_id)\
            .maybe_single()\
            .execute()

        # Guard against a None response from maybe_single(), as the share-link
        # lookups in this class already do.
        if not result or not result.data:
            return None

        template_data = result.data
        template_data['creator_name'] = await self._get_creator_name(client, template_data['creator_id'])
        return self._map_to_template(template_data)

    async def get_user_templates(self, creator_id: str) -> List[AgentTemplate]:
        """Return all templates created by ``creator_id``, newest first."""
        client = await self._db.client
        result = await client.table('agent_templates').select('*')\
            .eq('creator_id', creator_id)\
            .order('created_at', desc=True)\
            .execute()

        if not result.data:
            return []

        # All rows share one creator, so the name is looked up once.
        creator_name = await self._get_creator_name(client, creator_id)

        templates = []
        for template_data in result.data:
            template_data['creator_name'] = creator_name
            templates.append(self._map_to_template(template_data))
        return templates

    async def get_public_templates(
        self,
        is_kortix_team: Optional[bool] = None,
        limit: Optional[int] = None,
        offset: int = 0,
        search: Optional[str] = None,
        tags: Optional[List[str]] = None
    ) -> List[AgentTemplate]:
        """List public templates, ordered by downloads then publish date.

        Args:
            is_kortix_team: When not None, filter on the ``is_kortix_team`` flag.
            limit: Maximum rows to return (falsy means no limit).
            offset: Rows to skip.
            search: Case-insensitive substring matched against name or description.
            tags: Every listed tag must be present on the template.
        """
        client = await self._db.client

        query = client.table('agent_templates').select('*').eq('is_public', True)

        if is_kortix_team is not None:
            query = query.eq('is_kortix_team', is_kortix_team)

        if search:
            # Quote the user-supplied pattern so characters reserved in
            # PostgREST filter syntax (',', '.', ':', '(', ')') cannot break
            # or extend the or() expression.
            pattern = self._quote_filter_value(f"%{search}%")
            query = query.or_(f"name.ilike.{pattern},description.ilike.{pattern}")

        if tags:
            for tag in tags:
                query = query.contains('tags', [tag])

        query = query.order('download_count', desc=True)\
            .order('marketplace_published_at', desc=True)

        if limit:
            query = query.limit(limit)
        if offset:
            query = query.offset(offset)

        result = await query.execute()

        if not result.data:
            return []

        from utils.query_utils import batch_query_in

        creator_ids = list({template['creator_id'] for template in result.data})
        accounts_data = await batch_query_in(
            client=client,
            table_name='accounts',
            select_fields='id, name, slug',
            in_field='id',
            in_values=creator_ids,
            schema='basejump'
        )
        creator_names = {
            account['id']: account.get('name') or account.get('slug')
            for account in accounts_data
        }

        templates = []
        for template_data in result.data:
            template_data['creator_name'] = creator_names.get(template_data['creator_id'])
            templates.append(self._map_to_template(template_data))
        return templates

    async def publish_template(self, template_id: str, creator_id: str) -> bool:
        """Mark the creator's template public; return True if a row was updated."""
        logger.debug(f"Publishing template {template_id}")

        client = await self._db.client
        now = datetime.now(timezone.utc).isoformat()
        result = await client.table('agent_templates').update({
            'is_public': True,
            'marketplace_published_at': now,
            'updated_at': now
        }).eq('template_id', template_id)\
            .eq('creator_id', creator_id)\
            .execute()

        success = bool(result.data)
        if success:
            logger.debug(f"Published template {template_id}")
        return success

    async def unpublish_template(self, template_id: str, creator_id: str) -> bool:
        """Make the creator's template private; return True if a row was updated."""
        logger.debug(f"Unpublishing template {template_id}")

        client = await self._db.client
        result = await client.table('agent_templates').update({
            'is_public': False,
            'marketplace_published_at': None,
            'updated_at': datetime.now(timezone.utc).isoformat()
        }).eq('template_id', template_id)\
            .eq('creator_id', creator_id)\
            .execute()

        success = bool(result.data)
        if success:
            logger.debug(f"Unpublished template {template_id}")
        return success

    async def delete_template(self, template_id: str, creator_id: str) -> bool:
        """Delete a template. Only the creator can delete their templates.

        Returns:
            True if a row was deleted; False if the template is missing or is
            owned by someone else.
        """
        logger.debug(f"Deleting template {template_id} for user {creator_id}")

        client = await self._db.client

        # Check existence and ownership first so each case is logged distinctly.
        template_result = await client.table('agent_templates').select('creator_id')\
            .eq('template_id', template_id)\
            .maybe_single()\
            .execute()

        if not template_result or not template_result.data:
            logger.warning(f"Template {template_id} not found")
            return False

        owner_id = template_result.data['creator_id']
        if owner_id != creator_id:
            logger.warning(f"User {creator_id} cannot delete template {template_id} (owned by {owner_id})")
            return False

        # The creator_id filter is repeated so the delete itself is ownership-scoped.
        result = await client.table('agent_templates').delete()\
            .eq('template_id', template_id)\
            .eq('creator_id', creator_id)\
            .execute()

        success = bool(result.data)
        if success:
            logger.debug(f"Successfully deleted template {template_id}")
        return success

    async def increment_download_count(self, template_id: str) -> None:
        """Bump the download counter via the ``increment_template_download_count`` RPC."""
        client = await self._db.client
        await client.rpc('increment_template_download_count', {
            'template_id_param': template_id
        }).execute()

    async def validate_access(self, template: AgentTemplate, user_id: str) -> None:
        """Raise TemplateAccessDeniedError unless the template is public or owned by ``user_id``."""
        if template.creator_id != user_id and not template.is_public:
            raise TemplateAccessDeniedError("Access denied to template")

    async def _get_agent_by_id(self, agent_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``agents`` row for ``agent_id``, or None."""
        client = await self._db.client
        result = await client.table('agents').select('*')\
            .eq('agent_id', agent_id)\
            .maybe_single()\
            .execute()

        return result.data if result else None

    async def _get_agent_version_config(self, agent: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the config of the agent's current version, or {} if unavailable."""
        version_id = agent.get('current_version_id')
        if not version_id:
            return {}

        client = await self._db.client
        result = await client.table('agent_versions').select('config')\
            .eq('version_id', version_id)\
            .maybe_single()\
            .execute()

        if result and result.data and result.data.get('config'):
            return result.data['config']
        return {}

    async def _sanitize_config_for_template(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Return a publishable copy of ``config`` (see ``_fallback_sanitize_config``)."""
        return self._fallback_sanitize_config(config)

    def _fallback_sanitize_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce a raw agent config to the fields copied into a template.

        Keeps the prompt, model, AgentPress enabled flags, workflows, trigger
        definitions (with a per-provider subset of settings) and avatar
        metadata. Custom MCP connection config is discarded, except a Pipedream
        server's URL and headers other than ``profile_id``/``x-pd-app-slug``.
        Missing or null sections are treated as empty.
        """
        tools = config.get('tools') or {}
        raw_workflows = config.get('workflows') or []
        metadata = config.get('metadata') or {}

        return {
            'system_prompt': config.get('system_prompt', ''),
            'model': config.get('model'),
            'tools': {
                'agentpress': self._sanitize_agentpress_tools(tools.get('agentpress') or {}),
                'mcp': tools.get('mcp') or [],
                'custom_mcp': [
                    self._sanitize_custom_mcp(mcp)
                    for mcp in (tools.get('custom_mcp') or [])
                    if isinstance(mcp, dict)
                ],
            },
            'workflows': self._sanitize_workflows(raw_workflows),
            'triggers': self._sanitize_triggers(config.get('triggers') or [], raw_workflows),
            'metadata': {
                'avatar': metadata.get('avatar'),
                'avatar_color': metadata.get('avatar_color'),
            },
        }

    @staticmethod
    def _sanitize_agentpress_tools(agentpress_tools: Dict[str, Any]) -> Dict[str, bool]:
        """Reduce each AgentPress tool entry to its enabled flag (False if unknown)."""
        sanitized = {}
        for tool_name, tool_config in agentpress_tools.items():
            if isinstance(tool_config, dict):
                sanitized[tool_name] = tool_config.get('enabled', False)
            elif isinstance(tool_config, bool):
                sanitized[tool_name] = tool_config
            else:
                sanitized[tool_name] = False
        return sanitized

    @staticmethod
    def _sanitize_workflows(workflows: List[Any]) -> List[Dict[str, Any]]:
        """Copy the descriptive fields of each dict workflow (ids are dropped)."""
        return [
            {
                'name': workflow.get('name'),
                'description': workflow.get('description'),
                'status': workflow.get('status', 'draft'),
                'trigger_phrase': workflow.get('trigger_phrase'),
                'is_default': workflow.get('is_default', False),
                'steps': workflow.get('steps', []),
            }
            for workflow in workflows
            if isinstance(workflow, dict)
        ]

    @staticmethod
    def _sanitize_triggers(triggers: List[Any], workflows: List[Any]) -> List[Dict[str, Any]]:
        """Copy trigger definitions, keeping only a per-provider subset of config keys.

        ``workflows`` is the raw (unsanitized) list; it is used to translate a
        trigger's ``workflow_id`` into a ``workflow_name``, since sanitized
        workflows carry no ids.
        """
        sanitized_triggers = []
        for trigger in triggers:
            if not isinstance(trigger, dict):
                continue

            trigger_config = trigger.get('config') or {}
            provider_id = trigger_config.get('provider_id', '')

            sanitized_config = {
                'provider_id': provider_id,
                'agent_prompt': trigger_config.get('agent_prompt', ''),
                'execution_type': trigger_config.get('execution_type', 'agent')
            }

            if sanitized_config['execution_type'] == 'workflow':
                workflow_id = trigger_config.get('workflow_id')
                if workflow_id:
                    # First workflow with a matching id wins.
                    workflow_name = next(
                        (wf.get('name') for wf in workflows
                         if isinstance(wf, dict) and wf.get('id') == workflow_id),
                        None,
                    )
                    if workflow_name:
                        sanitized_config['workflow_name'] = workflow_name

                if 'workflow_input' in trigger_config:
                    sanitized_config['workflow_input'] = trigger_config['workflow_input']

            if provider_id == 'schedule':
                sanitized_config['cron_expression'] = trigger_config.get('cron_expression', '')
                sanitized_config['timezone'] = trigger_config.get('timezone', 'UTC')
            elif provider_id == 'composio':
                sanitized_config['trigger_slug'] = trigger_config.get('trigger_slug', '')
                if 'qualified_name' in trigger_config:
                    sanitized_config['qualified_name'] = trigger_config['qualified_name']

            sanitized_triggers.append({
                'name': trigger.get('name'),
                'description': trigger.get('description'),
                'trigger_type': trigger.get('trigger_type'),
                'is_active': trigger.get('is_active', True),
                'config': sanitized_config
            })
        return sanitized_triggers

    @staticmethod
    def _sanitize_custom_mcp(mcp: Dict[str, Any]) -> Dict[str, Any]:
        """Return a template-safe copy of one ``tools.custom_mcp`` entry.

        Always keeps name/type/display_name/enabledTools and fills in a
        qualified name. Pipedream entries also keep ``app_slug``, the URL and
        non-profile headers; Composio entries keep ``mcp_qualified_name`` and
        ``toolkit_slug``; all other config is replaced by {}.
        """
        mcp_name = mcp.get('name') or ''
        mcp_type = mcp.get('type', 'sse')
        original_config = mcp.get('config') or {}

        sanitized_mcp: Dict[str, Any] = {
            'name': mcp_name,
            'type': mcp_type,
            'display_name': mcp.get('display_name') or mcp_name,
            'enabledTools': mcp.get('enabledTools') or []
        }

        if mcp_type == 'pipedream':
            headers = original_config.get('headers') or {}
            app_slug = mcp.get('app_slug') or headers.get('x-pd-app-slug')
            qualified_name = mcp.get('qualifiedName')

            if not app_slug:
                if qualified_name and qualified_name.startswith('pipedream:'):
                    app_slug = qualified_name[len('pipedream:'):]
                else:
                    app_slug = mcp_name.lower().replace(' ', '').replace('(', '').replace(')', '')

            if not qualified_name:
                qualified_name = f"pipedream:{app_slug}"

            sanitized_mcp['qualifiedName'] = qualified_name
            sanitized_mcp['app_slug'] = app_slug
            # NOTE(review): every header other than profile_id/x-pd-app-slug is
            # copied into the published template; confirm none are user-specific.
            sanitized_mcp['config'] = {
                'url': original_config.get('url'),
                'headers': {k: v for k, v in headers.items()
                            if k not in ('profile_id', 'x-pd-app-slug')}
            }

        elif mcp_type == 'composio':
            qualified_name = (
                mcp.get('mcp_qualified_name') or
                original_config.get('mcp_qualified_name') or
                mcp.get('qualifiedName') or
                original_config.get('qualifiedName')
            )
            toolkit_slug = mcp.get('toolkit_slug') or original_config.get('toolkit_slug')

            if not qualified_name:
                if not toolkit_slug:
                    toolkit_slug = mcp_name.lower().replace(' ', '_')
                qualified_name = f"composio.{toolkit_slug}"
            elif not toolkit_slug:
                if qualified_name.startswith('composio.'):
                    toolkit_slug = qualified_name[len('composio.'):]
                else:
                    toolkit_slug = mcp_name.lower().replace(' ', '_')

            sanitized_mcp['mcp_qualified_name'] = qualified_name
            sanitized_mcp['toolkit_slug'] = toolkit_slug
            sanitized_mcp['config'] = {}

        else:
            qualified_name = mcp.get('qualifiedName')
            if not qualified_name:
                safe_name = mcp_name.replace(' ', '_').lower()
                qualified_name = f"custom_{mcp_type}_{safe_name}"

            sanitized_mcp['qualifiedName'] = qualified_name
            sanitized_mcp['config'] = {}

        return sanitized_mcp

    def _is_suna_default_agent(self, agent: Dict[str, Any]) -> bool:
        """True when the agent's metadata flags it as the Suna default agent."""
        # metadata may be missing or explicitly null.
        metadata = agent.get('metadata') or {}
        return bool(metadata.get('is_suna_default', False))

    async def _save_template(self, template: AgentTemplate) -> None:
        """Insert ``template`` into ``agent_templates`` (datetimes as ISO strings)."""
        client = await self._db.client

        template_data = {
            'template_id': template.template_id,
            'creator_id': template.creator_id,
            'name': template.name,
            'description': template.description,
            'config': template.config,
            'tags': template.tags,
            'is_public': template.is_public,
            'marketplace_published_at': template.marketplace_published_at.isoformat() if template.marketplace_published_at else None,
            'download_count': template.download_count,
            'created_at': template.created_at.isoformat(),
            'updated_at': template.updated_at.isoformat(),
            'avatar': template.avatar,
            'avatar_color': template.avatar_color,
            'profile_image_url': template.profile_image_url,
            'icon_name': template.icon_name,
            'icon_color': template.icon_color,
            'icon_background': template.icon_background,
            'metadata': template.metadata
        }

        await client.table('agent_templates').insert(template_data).execute()

    def _map_to_template(self, data: Dict[str, Any]) -> AgentTemplate:
        """Build an AgentTemplate from an ``agent_templates`` row.

        Null JSON/array/count columns become empty values, and timestamps are
        parsed with ``_parse_timestamp``.
        """
        published_at = data.get('marketplace_published_at')

        return AgentTemplate(
            template_id=data['template_id'],
            creator_id=data['creator_id'],
            name=data['name'],
            description=data.get('description'),
            config=data.get('config') or {},
            tags=data.get('tags') or [],
            is_public=bool(data.get('is_public', False)),
            is_kortix_team=bool(data.get('is_kortix_team', False)),
            marketplace_published_at=self._parse_timestamp(published_at) if published_at else None,
            download_count=data.get('download_count') or 0,
            created_at=self._parse_timestamp(data['created_at']),
            updated_at=self._parse_timestamp(data['updated_at']),
            avatar=data.get('avatar'),
            avatar_color=data.get('avatar_color'),
            profile_image_url=data.get('profile_image_url'),
            icon_name=data.get('icon_name'),
            icon_color=data.get('icon_color'),
            icon_background=data.get('icon_background'),
            metadata=data.get('metadata') or {},
            creator_name=data.get('creator_name')
        )

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse an ISO-8601 timestamp string, accepting a trailing 'Z'.

        Before Python 3.11, ``datetime.fromisoformat`` accepts only 3 or 6
        fractional-second digits, while Postgres omits trailing zeros (e.g.
        ``12:00:00.12+00:00``). On failure, the fraction is padded or truncated
        to 6 digits and parsing is retried.
        """
        text = value.replace('Z', '+00:00')
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            match = re.match(r'^([^.]*\.)(\d+)(.*)$', text)
            if not match:
                raise
            head, fraction, tail = match.groups()
            return datetime.fromisoformat(f"{head}{fraction[:6].ljust(6, '0')}{tail}")

    @staticmethod
    def _quote_filter_value(value: str) -> str:
        """Double-quote ``value`` for a PostgREST filter, escaping backslashes and quotes."""
        escaped = value.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'

    async def _get_creator_name(self, client: Any, creator_id: str) -> Optional[str]:
        """Return the basejump account's name (falling back to slug), or None."""
        creator_result = await client.schema('basejump').from_('accounts')\
            .select('id, name, slug').eq('id', creator_id).execute()

        if not creator_result.data:
            return None
        account = creator_result.data[0]
        return account.get('name') or account.get('slug')

    def _generate_share_id(self) -> str:
        """Return a random share id drawn from ``_SHARE_ID_ALPHABET`` via ``secrets``."""
        return ''.join(secrets.choice(self._SHARE_ID_ALPHABET) for _ in range(self._SHARE_ID_LENGTH))

    async def create_share_link(self, template_id: str, creator_id: str) -> str:
        """Return the creator's share id for a template, creating one if needed.

        Raises:
            TemplateNotFoundError: The template does not exist.
            TemplateAccessDeniedError: ``creator_id`` does not own the template.
            Exception: No unique id could be generated after 10 attempts.
        """
        logger.debug(f"Creating share link for template {template_id} by user {creator_id}")

        template = await self.get_template(template_id)
        if not template:
            raise TemplateNotFoundError("Template not found")

        if template.creator_id != creator_id:
            raise TemplateAccessDeniedError("You can only create share links for your own templates")

        client = await self._db.client

        existing = await client.table('template_share_links')\
            .select('share_id')\
            .eq('template_id', template_id)\
            .eq('created_by', creator_id)\
            .maybe_single()\
            .execute()

        if existing and existing.data:
            logger.debug(f"Returning existing share link {existing.data['share_id']} for template {template_id}")
            return existing.data['share_id']

        max_attempts = 10
        for _ in range(max_attempts):
            share_id = self._generate_share_id()
            try:
                await client.table('template_share_links').insert({
                    'share_id': share_id,
                    'template_id': template_id,
                    'created_by': creator_id
                }).execute()
            except Exception as e:
                # A colliding random id surfaces as a duplicate-key error; retry
                # with a new id and re-raise anything else.
                if 'duplicate key' in str(e).lower():
                    continue
                raise

            logger.debug(f"Created share link {share_id} for template {template_id}")
            return share_id

        raise Exception("Failed to generate unique share ID after multiple attempts")

    async def get_template_by_share_id(self, share_id: str) -> Optional[AgentTemplate]:
        """Resolve a share id to its template, recording the view.

        Returns None when the share link does not exist.
        """
        logger.debug(f"Getting template by share ID {share_id}")

        client = await self._db.client

        # Fetch the view counter with the link itself instead of issuing a
        # second query.
        share_link = await client.table('template_share_links')\
            .select('template_id, views_count')\
            .eq('share_id', share_id)\
            .maybe_single()\
            .execute()

        if not share_link or not share_link.data:
            logger.debug(f"Share link {share_id} not found")
            return None

        # NOTE: read-modify-write; concurrent views can lose increments. An
        # atomic increment would need a database-side function (compare
        # increment_download_count, which delegates to an RPC).
        views_count = share_link.data.get('views_count') or 0

        await client.table('template_share_links')\
            .update({
                'views_count': views_count + 1,
                'last_viewed_at': datetime.now(timezone.utc).isoformat()
            })\
            .eq('share_id', share_id)\
            .execute()

        return await self.get_template(share_link.data['template_id'])

    async def get_share_links_for_user(self, user_id: str) -> List[Dict[str, Any]]:
        """Return the share links created by ``user_id``, newest first."""
        client = await self._db.client

        result = await client.table('template_share_links')\
            .select('share_id, template_id, created_at, views_count, last_viewed_at')\
            .eq('created_by', user_id)\
            .order('created_at', desc=True)\
            .execute()

        return result.data or []


def get_template_service(db_connection: DBConnection) -> TemplateService:
    """Build a TemplateService bound to *db_connection*."""
    service = TemplateService(db_connection)
    return service