# suna/backend/agent/utils.py
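"""Utility helpers for agent runs: Redis response-list cleanup, STOP
signaling across instances, and per-account run/agent limit checks."""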

import json
from typing import Optional, List, Dict, Any
from datetime import datetime, timezone, timedelta
from utils.cache import Cache
from utils.logger import logger
from utils.config import config
from services import redis
from run_agent_background import update_agent_run_status

async def _cleanup_redis_response_list(agent_run_id: str):
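    """Delete the Redis list that buffers streamed responses for an agent run."""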
    try:
        response_list_key = f"agent_run:{agent_run_id}:responses"
        await redis.delete(response_list_key)
        logger.debug(f"Cleaned up Redis response list for agent run {agent_run_id}")
    except Exception as e:
        logger.warning(f"Failed to clean up Redis response list for {agent_run_id}: {str(e)}")

async def check_for_active_project_agent_run(client, project_id: str):
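    """Return the id of the first 'running' agent run in any of the project's threads, or None."""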
    project_threads = await client.table('threads').select('thread_id').eq('project_id', project_id).execute()
    project_thread_ids = [t['thread_id'] for t in project_threads.data]
    if project_thread_ids:
        active_runs = await client.table('agent_runs').select('id').in_('thread_id', project_thread_ids).eq('status', 'running').execute()
        if active_runs.data and len(active_runs.data) > 0:
            return active_runs.data[0]['id']
    return None

async def stop_agent_run(db, agent_run_id: str, error_message: Optional[str] = None):
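    """Mark an agent run as stopped (or failed when error_message is set),
    persist its buffered responses, and broadcast STOP signals to listeners.
    """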
logger.debug(f"Stopping agent run: {agent_run_id}")
    client = await db.client
    final_status = "failed" if error_message else "stopped"
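
    # Collect everything the run buffered in Redis so it can be persisted
    # together with the final status.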
    response_list_key = f"agent_run:{agent_run_id}:responses"
    all_responses = []
    try:
        all_responses_json = await redis.lrange(response_list_key, 0, -1)
        all_responses = [json.loads(r) for r in all_responses_json]
logger.debug(f"Fetched {len(all_responses)} responses from Redis for DB update on stop/fail: {agent_run_id}")
    except Exception as e:
        logger.error(f"Failed to fetch responses from Redis for {agent_run_id} during stop/fail: {e}")

    update_success = await update_agent_run_status(
        client, agent_run_id, final_status, error=error_message, responses=all_responses
    )
    if not update_success:
        logger.error(f"Failed to update database status for stopped/failed run {agent_run_id}")

    global_control_channel = f"agent_run:{agent_run_id}:control"
    try:
        await redis.publish(global_control_channel, "STOP")
        logger.debug(f"Published STOP signal to global channel {global_control_channel}")
    except Exception as e:
        logger.error(f"Failed to publish STOP signal to global channel {global_control_channel}: {str(e)}")
    try:
        instance_keys = await redis.keys(f"active_run:*:{agent_run_id}")
        logger.debug(f"Found {len(instance_keys)} active instance keys for agent run {agent_run_id}")
        for key in instance_keys:
            parts = key.split(":")
            if len(parts) == 3:
                instance_id_from_key = parts[1]
                instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id_from_key}"
                try:
                    await redis.publish(instance_control_channel, "STOP")
                    logger.debug(f"Published STOP signal to instance channel {instance_control_channel}")
                except Exception as e:
                    logger.warning(f"Failed to publish STOP signal to instance channel {instance_control_channel}: {str(e)}")
            else:
                logger.warning(f"Unexpected key format found: {key}")
        await _cleanup_redis_response_list(agent_run_id)
    except Exception as e:
        logger.error(f"Failed to find or signal active instances for {agent_run_id}: {str(e)}")
logger.debug(f"Successfully initiated stop process for agent run: {agent_run_id}")
async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
"""
Check if the account has reached the limit of 3 parallel agent runs within the past 24 hours.
Returns:
Dict with 'can_start' (bool), 'running_count' (int), 'running_thread_ids' (list)
"""
    try:
        result = await Cache.get(f"agent_run_limit:{account_id}")
        if result:
            return result

        # Calculate 24 hours ago
        twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
        twenty_four_hours_ago_iso = twenty_four_hours_ago.isoformat()
        logger.debug(f"Checking agent run limit for account {account_id} since {twenty_four_hours_ago_iso}")

        # Get all threads for this account
        threads_result = await client.table('threads').select('thread_id').eq('account_id', account_id).execute()
        if not threads_result.data:
            logger.debug(f"No threads found for account {account_id}")
            return {
                'can_start': True,
                'running_count': 0,
                'running_thread_ids': []
            }

        thread_ids = [thread['thread_id'] for thread in threads_result.data]
        logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")

        # Query for running agent runs within the past 24 hours for these threads
        running_runs_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()
        running_runs = running_runs_result.data or []
        running_count = len(running_runs)
        running_thread_ids = [run['thread_id'] for run in running_runs]
        logger.debug(f"Account {account_id} has {running_count} running agent runs in the past 24 hours")

        result = {
            'can_start': running_count < config.MAX_PARALLEL_AGENT_RUNS,
            'running_count': running_count,
            'running_thread_ids': running_thread_ids
        }
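        # Cached without an explicit ttl, so this relies on Cache.set's
        # default expiry (contrast the ttl=300 used for agent counts below).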
        await Cache.set(f"agent_run_limit:{account_id}", result)
        return result
    except Exception as e:
        logger.error(f"Error checking agent run limit for account {account_id}: {str(e)}")
        # In case of error, allow the run to proceed but log the error
        return {
            'can_start': True,
            'running_count': 0,
            'running_thread_ids': []
        }

async def check_agent_count_limit(client, account_id: str) -> Dict[str, Any]:
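    """Check whether the account may create another custom agent for its tier.

    Returns:
        Dict with 'can_create' (bool), 'current_count' (int), 'limit' (int), 'tier_name' (str)
    """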
    try:
        # In local mode, allow practically unlimited custom agents
        if config.ENV_MODE.value == "local":
            return {
                'can_create': True,
                'current_count': 0,  # Return 0 to avoid showing any limit warnings
                'limit': 999999,  # Practically unlimited
                'tier_name': 'local'
            }
        try:
            result = await Cache.get(f"agent_count_limit:{account_id}")
            if result:
                logger.debug(f"Cache hit for agent count limit: {account_id}")
                return result
        except Exception as cache_error:
            logger.warning(f"Cache read failed for agent count limit {account_id}: {str(cache_error)}")

        agents_result = await client.table('agents').select('agent_id, metadata').eq('account_id', account_id).execute()
        non_suna_agents = []
        for agent in agents_result.data or []:
            metadata = agent.get('metadata', {}) or {}
            is_suna_default = metadata.get('is_suna_default', False)
            if not is_suna_default:
                non_suna_agents.append(agent)
        current_count = len(non_suna_agents)
        logger.debug(f"Account {account_id} has {current_count} custom agents (excluding Suna defaults)")
        try:
            from services.billing import get_subscription_tier
            tier_name = await get_subscription_tier(client, account_id)
            logger.debug(f"Account {account_id} subscription tier: {tier_name}")
        except Exception as billing_error:
            logger.warning(f"Could not get subscription tier for {account_id}: {str(billing_error)}, defaulting to free")
            tier_name = 'free'
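
        # Unknown tiers fall back to the free-tier limit.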
        agent_limit = config.AGENT_LIMITS.get(tier_name, config.AGENT_LIMITS['free'])
        can_create = current_count < agent_limit
        result = {
            'can_create': can_create,
            'current_count': current_count,
            'limit': agent_limit,
            'tier_name': tier_name
        }
        try:
            await Cache.set(f"agent_count_limit:{account_id}", result, ttl=300)
        except Exception as cache_error:
            logger.warning(f"Cache write failed for agent count limit {account_id}: {str(cache_error)}")
logger.debug(f"Account {account_id} has {current_count}/{agent_limit} agents (tier: {tier_name}) - can_create: {can_create}")
2025-08-10 03:45:29 +08:00
return result
except Exception as e:
logger.error(f"Error checking agent count limit for account {account_id}: {str(e)}", exc_info=True)
return {
'can_create': True,
'current_count': 0,
'limit': config.AGENT_LIMITS['free'],
'tier_name': 'free'
}