mirror of https://github.com/kortix-ai/suna.git
348 lines
14 KiB
Python
348 lines
14 KiB
Python
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from typing import Dict, List, Any, Optional
|
|
from uuid import uuid4
|
|
|
|
from services.supabase import DBConnection
|
|
from utils.logger import logger
|
|
from .template_service import AgentTemplate, MCPRequirementValue, ConfigType, ProfileId, QualifiedName
|
|
|
|
@dataclass(frozen=True)
|
|
class AgentInstance:
|
|
instance_id: str
|
|
account_id: str
|
|
name: str
|
|
template_id: Optional[str] = None
|
|
description: Optional[str] = None
|
|
credential_mappings: Dict[QualifiedName, ProfileId] = field(default_factory=dict)
|
|
custom_system_prompt: Optional[str] = None
|
|
is_active: bool = True
|
|
is_default: bool = False
|
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
avatar: Optional[str] = None
|
|
avatar_color: Optional[str] = None
|
|
|
|
@dataclass
|
|
class TemplateInstallationRequest:
|
|
template_id: str
|
|
account_id: str
|
|
instance_name: Optional[str] = None
|
|
custom_system_prompt: Optional[str] = None
|
|
profile_mappings: Optional[Dict[QualifiedName, ProfileId]] = None
|
|
custom_mcp_configs: Optional[Dict[QualifiedName, ConfigType]] = None
|
|
|
|
@dataclass
|
|
class TemplateInstallationResult:
|
|
status: str
|
|
instance_id: Optional[str] = None
|
|
name: Optional[str] = None
|
|
missing_regular_credentials: List[Dict[str, Any]] = field(default_factory=list)
|
|
missing_custom_configs: List[Dict[str, Any]] = field(default_factory=list)
|
|
template_info: Optional[Dict[str, Any]] = None
|
|
|
|
class TemplateInstallationError(Exception):
|
|
pass
|
|
|
|
class InvalidCredentialError(Exception):
|
|
pass
|
|
|
|
class InstallationService:
|
|
def __init__(self, db_connection: DBConnection):
|
|
self._db = db_connection
|
|
|
|
async def install_template(self, request: TemplateInstallationRequest) -> TemplateInstallationResult:
|
|
logger.info(f"Installing template {request.template_id} for user {request.account_id}")
|
|
|
|
template = await self._get_template(request.template_id)
|
|
if not template:
|
|
raise TemplateInstallationError("Template not found")
|
|
|
|
await self._validate_access(template, request.account_id)
|
|
|
|
mcp_requirements = template.mcp_requirements
|
|
|
|
if not request.profile_mappings:
|
|
request.profile_mappings = await self._auto_map_profiles(
|
|
mcp_requirements,
|
|
request.account_id
|
|
)
|
|
|
|
missing_profiles, missing_configs = await self._validate_installation_requirements(
|
|
mcp_requirements,
|
|
request.profile_mappings,
|
|
request.custom_mcp_configs
|
|
)
|
|
|
|
if missing_profiles or missing_configs:
|
|
return TemplateInstallationResult(
|
|
status='configs_required',
|
|
missing_regular_credentials=missing_profiles,
|
|
missing_custom_configs=missing_configs,
|
|
template_info={
|
|
'template_id': template.template_id,
|
|
'name': template.name,
|
|
'description': template.description
|
|
}
|
|
)
|
|
|
|
agent_config = await self._build_agent_config(
|
|
template,
|
|
request,
|
|
mcp_requirements
|
|
)
|
|
|
|
agent_id = await self._create_agent(
|
|
template,
|
|
request,
|
|
agent_config
|
|
)
|
|
|
|
await self._create_initial_version(
|
|
agent_id,
|
|
request.account_id,
|
|
agent_config,
|
|
request.custom_system_prompt or template.system_prompt
|
|
)
|
|
|
|
await self._increment_download_count(template.template_id)
|
|
|
|
agent_name = request.instance_name or f"{template.name} (from marketplace)"
|
|
logger.info(f"Successfully installed template {template.template_id} as agent {agent_id}")
|
|
|
|
return TemplateInstallationResult(
|
|
status='installed',
|
|
instance_id=agent_id,
|
|
name=agent_name
|
|
)
|
|
|
|
async def _get_template(self, template_id: str) -> Optional[AgentTemplate]:
|
|
from .template_service import get_template_service
|
|
template_service = get_template_service(self._db)
|
|
return await template_service.get_template(template_id)
|
|
|
|
async def _validate_access(self, template: AgentTemplate, user_id: str) -> None:
|
|
if template.creator_id != user_id and not template.is_public:
|
|
raise TemplateInstallationError("Access denied to template")
|
|
|
|
async def _auto_map_profiles(
|
|
self,
|
|
requirements: List[MCPRequirementValue],
|
|
account_id: str
|
|
) -> Dict[QualifiedName, ProfileId]:
|
|
profile_mappings = {}
|
|
|
|
for req in requirements:
|
|
if not req.is_custom():
|
|
from credentials import get_profile_service
|
|
profile_service = get_profile_service(self._db)
|
|
default_profile = await profile_service.get_default_profile(
|
|
account_id, req.qualified_name
|
|
)
|
|
|
|
if default_profile:
|
|
profile_mappings[req.qualified_name] = default_profile.profile_id
|
|
logger.info(f"Auto-mapped {req.qualified_name} to profile {default_profile.profile_id}")
|
|
|
|
return profile_mappings
|
|
|
|
async def _validate_installation_requirements(
|
|
self,
|
|
requirements: List[MCPRequirementValue],
|
|
profile_mappings: Optional[Dict[QualifiedName, ProfileId]],
|
|
custom_configs: Optional[Dict[QualifiedName, ConfigType]]
|
|
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
|
missing_profiles = []
|
|
missing_configs = []
|
|
|
|
profile_mappings = profile_mappings or {}
|
|
custom_configs = custom_configs or {}
|
|
|
|
for req in requirements:
|
|
if req.is_custom():
|
|
if req.qualified_name not in custom_configs:
|
|
field_descriptions = {}
|
|
for field in req.required_config:
|
|
if field == 'url':
|
|
field_descriptions[field] = {
|
|
'type': 'text',
|
|
'placeholder': 'https://example.com/mcp/endpoint',
|
|
'description': f'The endpoint URL for the {req.display_name} MCP server'
|
|
}
|
|
else:
|
|
field_descriptions[field] = {
|
|
'type': 'text',
|
|
'placeholder': f'Enter {field}',
|
|
'description': f'Required configuration for {field}'
|
|
}
|
|
|
|
missing_configs.append({
|
|
'qualified_name': req.qualified_name,
|
|
'display_name': req.display_name,
|
|
'required_config': req.required_config,
|
|
'custom_type': req.custom_type,
|
|
'field_descriptions': field_descriptions
|
|
})
|
|
else:
|
|
if req.qualified_name not in profile_mappings:
|
|
missing_profiles.append({
|
|
'qualified_name': req.qualified_name,
|
|
'display_name': req.display_name,
|
|
'enabled_tools': req.enabled_tools,
|
|
'required_config': req.required_config
|
|
})
|
|
|
|
return missing_profiles, missing_configs
|
|
|
|
async def _build_agent_config(
|
|
self,
|
|
template: AgentTemplate,
|
|
request: TemplateInstallationRequest,
|
|
requirements: List[MCPRequirementValue]
|
|
) -> Dict[str, Any]:
|
|
agentpress_tools = {}
|
|
template_agentpress = template.agentpress_tools or {}
|
|
for tool_name, tool_config in template_agentpress.items():
|
|
if isinstance(tool_config, dict):
|
|
agentpress_tools[tool_name] = tool_config.get('enabled', True)
|
|
else:
|
|
agentpress_tools[tool_name] = tool_config
|
|
|
|
agent_config = {
|
|
'tools': {
|
|
'agentpress': agentpress_tools,
|
|
'mcp': [],
|
|
'custom_mcp': []
|
|
},
|
|
'metadata': template.config.get('metadata', {}),
|
|
'system_prompt': request.custom_system_prompt or template.system_prompt
|
|
}
|
|
|
|
from credentials import get_profile_service
|
|
profile_service = get_profile_service(self._db)
|
|
|
|
for req in requirements:
|
|
if req.is_custom():
|
|
config = request.custom_mcp_configs.get(req.qualified_name, {})
|
|
|
|
original_name = req.display_name
|
|
if req.qualified_name.startswith('custom_') and '_' in req.qualified_name[7:]:
|
|
parts = req.qualified_name.split('_', 2)
|
|
if len(parts) >= 3:
|
|
original_name = parts[2].replace('_', ' ').title()
|
|
|
|
custom_mcp = {
|
|
'name': original_name,
|
|
'type': req.custom_type or 'sse',
|
|
'config': config,
|
|
'enabledTools': req.enabled_tools
|
|
}
|
|
agent_config['tools']['custom_mcp'].append(custom_mcp)
|
|
else:
|
|
profile_id = request.profile_mappings.get(req.qualified_name)
|
|
if profile_id:
|
|
profile = await profile_service.get_profile(request.account_id, profile_id)
|
|
if profile:
|
|
if req.qualified_name.startswith('pipedream:'):
|
|
app_slug = profile.config.get('app_slug', req.qualified_name.split(':')[1])
|
|
|
|
pipedream_config = {
|
|
'url': 'https://remote.mcp.pipedream.net',
|
|
'headers': {
|
|
'x-pd-app-slug': app_slug
|
|
},
|
|
'profile_id': profile_id
|
|
}
|
|
|
|
mcp_config = {
|
|
'name': req.display_name,
|
|
'type': 'pipedream',
|
|
'config': pipedream_config,
|
|
'enabledTools': req.enabled_tools
|
|
}
|
|
agent_config['tools']['custom_mcp'].append(mcp_config)
|
|
else:
|
|
mcp_config = {
|
|
'name': req.display_name or req.qualified_name,
|
|
'type': 'sse',
|
|
'config': profile.config,
|
|
'enabledTools': req.enabled_tools
|
|
}
|
|
agent_config['tools']['mcp'].append(mcp_config)
|
|
|
|
return agent_config
|
|
|
|
async def _create_agent(
|
|
self,
|
|
template: AgentTemplate,
|
|
request: TemplateInstallationRequest,
|
|
agent_config: Dict[str, Any]
|
|
) -> str:
|
|
agent_id = str(uuid4())
|
|
agent_name = request.instance_name or f"{template.name} (from marketplace)"
|
|
|
|
client = await self._db.client
|
|
|
|
agent_data = {
|
|
'agent_id': agent_id,
|
|
'account_id': request.account_id,
|
|
'name': agent_name,
|
|
'description': template.description,
|
|
'avatar': template.avatar,
|
|
'avatar_color': template.avatar_color,
|
|
'metadata': {
|
|
**template.metadata,
|
|
'created_from_template': template.template_id,
|
|
'template_name': template.name
|
|
},
|
|
'created_at': datetime.now(timezone.utc).isoformat(),
|
|
'updated_at': datetime.now(timezone.utc).isoformat()
|
|
}
|
|
|
|
await client.table('agents').insert(agent_data).execute()
|
|
|
|
logger.info(f"Created agent {agent_id} from template {template.template_id}")
|
|
return agent_id
|
|
|
|
async def _create_initial_version(
|
|
self,
|
|
agent_id: str,
|
|
user_id: str,
|
|
agent_config: Dict[str, Any],
|
|
system_prompt: str
|
|
) -> None:
|
|
try:
|
|
from agent.versioning.facade import version_manager
|
|
|
|
tools = agent_config.get('tools', {})
|
|
configured_mcps = tools.get('mcp', [])
|
|
custom_mcps = tools.get('custom_mcp', [])
|
|
agentpress_tools = tools.get('agentpress', {})
|
|
|
|
await version_manager.create_version(
|
|
agent_id=agent_id,
|
|
user_id=user_id,
|
|
system_prompt=system_prompt,
|
|
configured_mcps=configured_mcps,
|
|
custom_mcps=custom_mcps,
|
|
agentpress_tools=agentpress_tools,
|
|
version_name="v1",
|
|
change_description="Initial version from template"
|
|
)
|
|
|
|
logger.info(f"Created initial version for agent {agent_id}")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create initial version for agent {agent_id}: {e}")
|
|
|
|
async def _increment_download_count(self, template_id: str) -> None:
|
|
client = await self._db.client
|
|
try:
|
|
await client.rpc('increment_template_download_count', {
|
|
'template_id_param': template_id
|
|
}).execute()
|
|
except Exception as e:
|
|
logger.warning(f"Failed to increment download count for template {template_id}: {e}")
|
|
|
|
def get_installation_service(db_connection: DBConnection) -> InstallationService:
|
|
return InstallationService(db_connection) |