mirror of https://github.com/kortix-ai/suna.git
437 lines
19 KiB
Python
437 lines
19 KiB
Python
from typing import List, Dict, Any, Optional
|
|
from core.utils.pagination import PaginationService, PaginationParams, PaginatedResponse
|
|
from core.utils.logger import logger
|
|
from .config_helper import extract_agent_config
|
|
from core.utils.query_utils import batch_query_in
|
|
|
|
|
|
class AgentFilters:
    """Bundle of optional filter and sort settings for an agent listing.

    Every constructor argument is stored on the instance under the same
    name. ``tools`` is replaced by an empty list when it is omitted or
    otherwise falsy, so consumers can always iterate it or take its length.
    """

    def __init__(
        self,
        search: Optional[str] = None,
        has_default: Optional[bool] = None,
        has_mcp_tools: Optional[bool] = None,
        has_agentpress_tools: Optional[bool] = None,
        tools: Optional[List[str]] = None,
        content_type: Optional[str] = None,
        sort_by: str = "created_at",
        sort_order: str = "desc"
    ):
        # Attributes are created in signature order; the instance __dict__
        # is written to the debug log, so the order is kept stable.
        self.search, self.has_default = search, has_default
        self.has_mcp_tools, self.has_agentpress_tools = has_mcp_tools, has_agentpress_tools
        self.tools = tools if tools else []
        self.content_type = content_type
        self.sort_by, self.sort_order = sort_by, sort_order
|
|
|
|
|
|
class AgentService:
    """Query service that lists a user's agents and agent templates."""

    def __init__(self, db_client):
        """Keep a reference to the database client used for every query.

        Args:
            db_client: Client object stored as ``self.db``; the methods of
                this class call ``table(...)`` and ``schema(...)`` on it.
        """
        self.db = db_client
|
|
|
|
async def get_agents_paginated(
    self,
    user_id: str,
    pagination_params: PaginationParams,
    filters: AgentFilters
) -> PaginatedResponse[Dict[str, Any]]:
    """Return one page of the user's agents or templates.

    Args:
        user_id: Identifier the underlying queries are scoped to.
        pagination_params: Page selection forwarded to the chosen listing.
        filters: Filter and sort settings; ``content_type == "templates"``
            selects the template listing, any other value (including
            ``None``) selects the agent listing.

    Returns:
        The paginated response produced by the selected listing method.

    Raises:
        Exception: Any error from the listing is logged and re-raised.
    """
    try:
        logger.debug(f"Fetching content for user {user_id} with filters: {filters.__dict__}")

        # Handle unified agent/template content types: only "templates"
        # switches away from the default agent listing.
        fetch_page = (
            self._get_user_templates_paginated
            if filters.content_type == "templates"
            else self._get_only_agents_paginated
        )
        return await fetch_page(user_id, pagination_params, filters)

    except Exception as e:
        logger.error(f"Error fetching content for user {user_id}: {e}", exc_info=True)
        raise
|
|
|
|
async def _get_only_agents_paginated(
    self,
    user_id: str,
    pagination_params: PaginationParams,
    filters: AgentFilters
) -> PaginatedResponse[Dict[str, Any]]:
    """Get only agents (not templates) with pagination.

    Builds the row query and the count query, then picks a strategy:
    database-level pagination when no tool-related filter or sort is
    requested, otherwise the in-memory filtering path.
    """
    base_query = self._build_base_query(user_id, filters)
    count_query = self._build_count_query(user_id, filters)

    # The database can page the result on its own only when nothing
    # depends on per-agent tool configuration.
    database_can_paginate = (
        filters.has_mcp_tools is None
        and filters.has_agentpress_tools is None
        and len(filters.tools) == 0
        and filters.sort_by != "tools_count"
    )

    if database_can_paginate:
        return await self._get_agents_database_paginated(
            base_query, count_query, pagination_params, filters
        )

    return await self._get_agents_with_complex_filtering(
        user_id, pagination_params, filters, base_query
    )
|
|
|
|
async def _get_user_templates_paginated(
    self,
    user_id: str,
    pagination_params: PaginationParams,
    filters: AgentFilters
) -> PaginatedResponse[Dict[str, Any]]:
    """Get user's templates with pagination (both public and private).

    Queries ``agent_templates`` rows whose ``creator_id`` equals *user_id*,
    optionally narrowed by a case-insensitive substring match on ``name``
    or ``description``, and converts every row of the requested page to
    the agent response shape.

    Args:
        user_id: Creator id the template rows are restricted to.
        pagination_params: Page selection handed to the pagination service.
        filters: Supplies ``search``, ``sort_by`` and ``sort_order``.
            ``sort_by`` may be ``"name"`` or ``"download_count"``; any
            other value sorts by ``created_at``.

    Returns:
        A paginated response whose items are the transformed templates.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        # Build template queries (two independent builders: rows and count).
        base_query = self.db.table('agent_templates').select('*').eq('creator_id', user_id)
        count_query = self.db.table('agent_templates').select('*', count='exact').eq('creator_id', user_id)

        # Apply search filter
        if filters.search:
            # The term is embedded in a PostgREST or=(...) expression where
            # ',', '(' and ')' are grammar. Double-quoting the pattern (and
            # backslash-escaping '\' and '"') keeps the whole term a single
            # literal instead of breaking or altering the filter.
            escaped_term = filters.search.replace('\\', '\\\\').replace('"', '\\"')
            search_filter = (
                f'name.ilike."%{escaped_term}%",'
                f'description.ilike."%{escaped_term}%"'
            )
            base_query = base_query.or_(search_filter)
            count_query = count_query.or_(search_filter)

        # Apply sorting for templates
        descending = filters.sort_order == "desc"
        if filters.sort_by == "name":
            base_query = base_query.order('name', desc=descending)
        elif filters.sort_by == "download_count":
            base_query = base_query.order('download_count', desc=descending)
            base_query = base_query.order('created_at', desc=True)  # Secondary sort
        else:
            # Default to created_at
            base_query = base_query.order('created_at', desc=descending)

        # Use pagination service
        paginated_result = await PaginationService.paginate_database_query(
            base_query=base_query,
            params=pagination_params,
            count_query=count_query
        )

        # Transform template data to match agent response format
        # (sequentially, one row at a time, in page order).
        template_responses = []
        for template_data in paginated_result.data:
            transformed_template = await self._transform_template_to_agent_format(template_data)
            template_responses.append(transformed_template)

        return PaginatedResponse(
            data=template_responses,
            pagination=paginated_result.pagination
        )

    except Exception as e:
        logger.error(f"Error fetching templates for user {user_id}: {e}", exc_info=True)
        raise
|
|
|
|
def _build_base_query(self, user_id: str, filters: AgentFilters):
    """Build the row query for the user's agents.

    Selects all columns of ``agents`` where ``account_id`` equals
    *user_id*, then applies the optional search and ``is_default``
    filters. Ordering is added unless ``filters.sort_by`` is
    ``"tools_count"``; in that case the query is returned unordered.

    Args:
        user_id: Account id the rows are restricted to.
        filters: Supplies ``search``, ``has_default``, ``sort_by`` and
            ``sort_order``.

    Returns:
        The query builder, not yet executed.
    """
    query = self.db.table('agents').select('*').eq("account_id", user_id)

    if filters.search:
        # The term is embedded in a PostgREST or=(...) expression where
        # ',', '(' and ')' are grammar. Double-quoting the pattern (and
        # backslash-escaping '\' and '"') keeps the whole term a single
        # literal instead of breaking or altering the filter.
        escaped_term = filters.search.replace('\\', '\\\\').replace('"', '\\"')
        query = query.or_(
            f'name.ilike."%{escaped_term}%",'
            f'description.ilike."%{escaped_term}%"'
        )

    if filters.has_default is not None:
        query = query.eq("is_default", filters.has_default)

    if filters.sort_by != "tools_count":
        # Only known columns may be used for ordering; anything else
        # falls back to created_at.
        sort_column = filters.sort_by if filters.sort_by in ["name", "created_at", "updated_at"] else "created_at"
        query = query.order(sort_column, desc=(filters.sort_order == "desc"))

    return query
|
|
|
|
def _build_count_query(self, user_id: str, filters: AgentFilters):
    """Build the exact-count query matching the user's agent row query.

    Applies the same ``account_id``, search and ``is_default`` conditions
    as the row query but requests ``count='exact'`` and adds no ordering.

    Args:
        user_id: Account id the rows are restricted to.
        filters: Supplies ``search`` and ``has_default``.

    Returns:
        The query builder, not yet executed.
    """
    query = self.db.table('agents').select('*', count='exact').eq("account_id", user_id)

    if filters.search:
        # The term is embedded in a PostgREST or=(...) expression where
        # ',', '(' and ')' are grammar. Double-quoting the pattern (and
        # backslash-escaping '\' and '"') keeps the whole term a single
        # literal instead of breaking or altering the filter.
        escaped_term = filters.search.replace('\\', '\\\\').replace('"', '\\"')
        query = query.or_(
            f'name.ilike."%{escaped_term}%",'
            f'description.ilike."%{escaped_term}%"'
        )

    if filters.has_default is not None:
        query = query.eq("is_default", filters.has_default)

    return query
|
|
|
|
async def _get_agents_database_paginated(
    self,
    base_query,
    count_query,
    pagination_params: PaginationParams,
    filters: AgentFilters
) -> PaginatedResponse[Dict[str, Any]]:
    """Page the agent query in the database and format each returned row.

    Args:
        base_query: Row query to paginate.
        count_query: Query used by the pagination service for the total.
        pagination_params: Page selection.
        filters: Accepted for signature parity; not read by this method.

    Returns:
        A paginated response holding the transformed rows together with
        the pagination metadata reported by the pagination service.
    """
    page = await PaginationService.paginate_database_query(
        base_query=base_query,
        params=pagination_params,
        count_query=count_query
    )

    # Rows are transformed one after another, preserving page order.
    formatted_rows = [await self._transform_agent_data(row) for row in page.data]

    return PaginatedResponse(data=formatted_rows, pagination=page.pagination)
|
|
|
|
async def _get_agents_with_complex_filtering(
    self,
    user_id: str,
    pagination_params: PaginationParams,
    filters: AgentFilters,
    base_query
) -> PaginatedResponse[Dict[str, Any]]:
    """Filter, sort and paginate agents in memory.

    Executes *base_query* without paging, loads the current version of
    every returned agent in one batch, keeps the agents that pass the
    tool-related filters, orders them, formats them and finally lets the
    pagination service slice the requested page.

    Args:
        user_id: Not read here; kept for signature compatibility.
        pagination_params: Page selection.
        filters: Tool filters plus ``sort_by`` / ``sort_order``.
        base_query: Row query as built by ``_build_base_query``, which
            already carries the requested ordering unless ``sort_by`` is
            ``"tools_count"``.

    Returns:
        The paginated response for the filtered agent list.
    """
    all_agents_result = await base_query.execute()
    all_agents = all_agents_result.data or []

    if not all_agents:
        # NOTE(review): assumes PaginationService exposes PaginationMeta as
        # an attribute - confirm against core.utils.pagination.
        return PaginatedResponse(
            data=[],
            pagination=PaginationService.PaginationMeta(
                current_page=pagination_params.page,
                page_size=pagination_params.page_size,
                total_items=0,
                total_pages=0,
                has_next=False,
                has_previous=False
            )
        )

    version_map = await self._load_agent_versions_batch(all_agents)

    filtered_agents = [
        agent_data
        for agent_data in all_agents
        if await self._passes_complex_filters(agent_data, version_map, filters)
    ]

    if filters.sort_by == "tools_count":
        # The helper always returns most-tools-first; flip the list when the
        # caller did not ask for descending order so sort_order is honoured.
        filtered_agents = await self._sort_by_tools_count(filtered_agents, version_map)
        if filters.sort_order != "desc":
            filtered_agents.reverse()
    # For every other sort key the rows arrive from the database already in
    # the requested direction. Filtering keeps their relative order, so no
    # further reordering is applied here (reversing again would invert a
    # descending request into ascending output).

    # Transform to proper response format
    agent_responses = [
        await self._transform_agent_data(agent_data, version_map.get(agent_data['agent_id']))
        for agent_data in filtered_agents
    ]

    return await PaginationService.paginate_filtered_dataset(
        all_items=agent_responses,
        params=pagination_params
    )
|
|
|
|
async def _load_agent_versions_batch(self, agents: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
|
version_map = {}
|
|
version_ids = list({agent['current_version_id'] for agent in agents if agent.get('current_version_id')})
|
|
|
|
if version_ids:
|
|
try:
|
|
versions_data = await batch_query_in(
|
|
client=self.db,
|
|
table_name='agent_versions',
|
|
select_fields='version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config',
|
|
in_field='version_id',
|
|
in_values=version_ids
|
|
)
|
|
|
|
for row in versions_data:
|
|
config = row.get('config') or {}
|
|
tools = config.get('tools') or {}
|
|
version_dict = {
|
|
'version_id': row['version_id'],
|
|
'agent_id': row['agent_id'],
|
|
'version_number': row['version_number'],
|
|
'version_name': row['version_name'],
|
|
'system_prompt': config.get('system_prompt', ''),
|
|
'model': config.get('model'),
|
|
'configured_mcps': tools.get('mcp', []),
|
|
'custom_mcps': tools.get('custom_mcp', []),
|
|
'agentpress_tools': tools.get('agentpress', {}),
|
|
'is_active': row.get('is_active', False),
|
|
'created_at': row.get('created_at'),
|
|
'updated_at': row.get('updated_at') or row.get('created_at'),
|
|
'created_by': row.get('created_by'),
|
|
'config': config # Include the full config for compatibility
|
|
}
|
|
version_map[row['agent_id']] = version_dict
|
|
except Exception as e:
|
|
logger.warning(f"Failed to batch load agent versions: {e}")
|
|
|
|
return version_map
|
|
|
|
async def _passes_complex_filters(
    self,
    agent_data: Dict[str, Any],
    version_map: Dict[str, Dict[str, Any]],
    filters: AgentFilters
) -> bool:
    """Tell whether an agent satisfies the tool-related filters.

    Args:
        agent_data: Agent row; ``agent_id`` selects its version entry.
        version_map: Version data keyed by agent id (missing ids allowed).
        filters: ``has_mcp_tools``, ``has_agentpress_tools`` and ``tools``
            are evaluated; a filter left at ``None`` / empty is skipped.

    Returns:
        ``True`` when every active filter matches, ``False`` otherwise.
        The ``tools`` filter matches when at least one requested entry
        (``"mcp:<name>"`` or ``"agentpress:<tool>"``) is present.
    """
    version_data = version_map.get(agent_data['agent_id'])
    agent_config = extract_agent_config(agent_data, version_data)

    # Treat a missing (None) tool section as empty so the checks below can
    # iterate safely; the tools-count sort applies the same guard.
    configured_mcps = agent_config['configured_mcps'] or []
    agentpress_tools = agent_config['agentpress_tools'] or {}

    def is_enabled(tool_data) -> bool:
        # Only dict entries carrying a truthy 'enabled' flag count.
        return bool(tool_data and isinstance(tool_data, dict) and tool_data.get('enabled', False))

    if filters.has_mcp_tools is not None:
        has_mcp = len(configured_mcps) > 0
        if filters.has_mcp_tools != has_mcp:
            return False

    if filters.has_agentpress_tools is not None:
        has_enabled_tools = any(is_enabled(tool_data) for tool_data in agentpress_tools.values())
        if filters.has_agentpress_tools != has_enabled_tools:
            return False

    if filters.tools:
        agent_tools = set()

        for mcp in configured_mcps:
            if isinstance(mcp, dict) and 'name' in mcp:
                agent_tools.add(f"mcp:{mcp['name']}")

        for tool_name, tool_data in agentpress_tools.items():
            if is_enabled(tool_data):
                agent_tools.add(f"agentpress:{tool_name}")

        if not any(tool in agent_tools for tool in filters.tools):
            return False

    return True
|
|
|
|
async def _sort_by_tools_count(
    self,
    agents: List[Dict[str, Any]],
    version_map: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Return the agents ordered from most to fewest tools.

    An agent's tool count is the number of configured MCP entries plus
    the number of agentpress tools whose entry is a dict with a truthy
    ``enabled`` flag. The input list is not modified.
    """
    def tool_total(agent):
        config = extract_agent_config(agent, version_map.get(agent['agent_id']))

        mcps = config['configured_mcps']
        agentpress = config['agentpress_tools']

        total = len(mcps) if mcps else 0
        if agentpress:
            total += sum(
                1
                for entry in agentpress.values()
                if entry and isinstance(entry, dict) and entry.get('enabled', False)
            )
        return total

    ranked = list(agents)
    ranked.sort(key=tool_total, reverse=True)
    return ranked
|
|
|
|
async def _transform_agent_data(
    self,
    agent_data: Dict[str, Any],
    version_data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Convert an agent row into the response dictionary.

    Args:
        agent_data: Agent row; ``agent_id``, ``name``, ``created_at`` and
            ``updated_at`` are required, the remaining fields are optional.
        version_data: Optional version entry passed on to
            ``extract_agent_config`` together with the row.

    Returns:
        The response dict combining row fields with the extracted config.
    """
    config = extract_agent_config(agent_data, version_data)

    # Required config entries are read first, in a fixed order.
    prompt, mcps, custom, agentpress = (
        config[key]
        for key in ('system_prompt', 'configured_mcps', 'custom_mcps', 'agentpress_tools')
    )
    version_info = config.get('current_version')

    response = {
        "agent_id": agent_data['agent_id'],
        "name": agent_data['name'],
        "system_prompt": prompt,
        "configured_mcps": mcps,
        "custom_mcps": custom,
        "agentpress_tools": agentpress,
        "is_default": agent_data.get('is_default', False),
        "is_public": agent_data.get('is_public', False),
        "tags": agent_data.get('tags', []),
        "icon_name": config.get('icon_name'),
        "icon_color": config.get('icon_color'),
        "icon_background": config.get('icon_background'),
        "created_at": agent_data['created_at'],
        "updated_at": agent_data['updated_at'],
        "current_version_id": agent_data.get('current_version_id'),
        "version_count": agent_data.get('version_count', 1),
        "current_version": version_info,
        "metadata": agent_data.get('metadata')
    }
    return response
|
|
|
|
async def _transform_template_to_agent_format(self, template_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
try:
|
|
creator_name = None
|
|
if template_data.get('creator_id'):
|
|
try:
|
|
creator_result = await self.db.schema('basejump').from_('accounts').select('name, slug').eq('id', template_data['creator_id']).single().execute()
|
|
if creator_result.data:
|
|
creator_name = creator_result.data.get('name') or creator_result.data.get('slug')
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch creator name for template {template_data.get('template_id')}: {e}")
|
|
|
|
return {
|
|
"agent_id": template_data.get('template_id', ''),
|
|
"name": template_data.get('name', ''),
|
|
"system_prompt": template_data.get('system_prompt', ''),
|
|
"configured_mcps": template_data.get('mcp_requirements', []),
|
|
"custom_mcps": [],
|
|
"agentpress_tools": template_data.get('agentpress_tools', {}),
|
|
"is_default": False,
|
|
"icon_name": template_data.get('icon_name'),
|
|
"icon_color": template_data.get('icon_color'),
|
|
"icon_background": template_data.get('icon_background'),
|
|
"created_at": template_data.get('created_at', ''),
|
|
"updated_at": template_data.get('updated_at'),
|
|
"is_public": template_data.get('is_public', False),
|
|
"tags": template_data.get('tags', []),
|
|
"current_version_id": None,
|
|
"version_count": 0,
|
|
"current_version": None,
|
|
"metadata": {
|
|
**(template_data.get('metadata', {})),
|
|
"is_template": True,
|
|
"creator_name": creator_name
|
|
},
|
|
|
|
"template_id": template_data.get('template_id'),
|
|
"mcp_requirements": template_data.get('mcp_requirements', []),
|
|
"model": template_data.get('metadata', {}).get('model'),
|
|
"marketplace_published_at": template_data.get('marketplace_published_at'),
|
|
"download_count": template_data.get('download_count', 0),
|
|
"creator_name": creator_name,
|
|
"creator_id": template_data.get('creator_id'),
|
|
"is_kortix_team": template_data.get('is_kortix_team', False)
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Error transforming template data: {e}", exc_info=True)
|
|
return {
|
|
"agent_id": template_data.get('template_id', 'unknown'),
|
|
"name": template_data.get('name', 'Unknown Template'),
|
|
"system_prompt": template_data.get('system_prompt', ''),
|
|
"configured_mcps": [],
|
|
"custom_mcps": [],
|
|
"agentpress_tools": {},
|
|
"is_default": False,
|
|
"icon_name": None,
|
|
"icon_color": None,
|
|
"icon_background": None,
|
|
"created_at": template_data.get('created_at', ''),
|
|
"updated_at": template_data.get('updated_at'),
|
|
"is_public": template_data.get('is_public', False),
|
|
"tags": [],
|
|
"current_version_id": None,
|
|
"version_count": 0,
|
|
"current_version": None,
|
|
"metadata": {"is_template": True, "transform_error": True},
|
|
|
|
"template_id": template_data.get('template_id', 'unknown'),
|
|
"mcp_requirements": [],
|
|
"model": None,
|
|
"marketplace_published_at": template_data.get('marketplace_published_at'),
|
|
"download_count": 0,
|
|
"creator_name": None,
|
|
"creator_id": template_data.get('creator_id'),
|
|
"is_kortix_team": False
|
|
} |