import json
from typing import Optional, List, Dict, Any
from datetime import datetime, timezone, timedelta

from utils.cache import Cache
from utils.logger import logger
from utils.config import config
from services import redis
from run_agent_background import update_agent_run_status


async def _cleanup_redis_response_list(agent_run_id: str):
    """Delete the Redis list that buffers responses for ``agent_run_id``.

    Best-effort: any exception is logged as a warning and swallowed.
    """
    try:
        response_list_key = f"agent_run:{agent_run_id}:responses"
        await redis.delete(response_list_key)
        logger.debug(f"Cleaned up Redis response list for agent run {agent_run_id}")
    except Exception as e:
        logger.warning(f"Failed to clean up Redis response list for {agent_run_id}: {str(e)}")


async def check_for_active_project_agent_run(client, project_id: str):
    """Find a running agent run in any thread belonging to ``project_id``.

    Args:
        client: Async DB client exposing ``table(...)`` query builders.
        project_id: Project whose threads are searched.

    Returns:
        The ``id`` of the first ``agent_runs`` row with status ``'running'``
        whose ``thread_id`` belongs to the project, or ``None`` if the project
        has no threads or no running agent run.
    """
    project_threads = await client.table('threads').select('thread_id').eq('project_id', project_id).execute()
    project_thread_ids = [t['thread_id'] for t in (project_threads.data or [])]
    if not project_thread_ids:
        return None

    active_runs = await (
        client.table('agent_runs')
        .select('id')
        .in_('thread_id', project_thread_ids)
        .eq('status', 'running')
        .execute()
    )
    if active_runs.data:
        return active_runs.data[0]['id']
    return None


async def stop_agent_run(db, agent_run_id: str, error_message: Optional[str] = None):
    """Stop an agent run, persist its buffered responses and signal its workers.

    Steps (each one logs its own failures and never aborts the later ones):
      1. Read every response buffered in the Redis list
         ``agent_run:<id>:responses``.
      2. Persist the final status (``'failed'`` when ``error_message`` is
         given, otherwise ``'stopped'``) together with those responses via
         ``update_agent_run_status``.
      3. Publish ``"STOP"`` on the global control channel
         ``agent_run:<id>:control`` and on the per-instance channel
         ``agent_run:<id>:control:<instance>`` for every
         ``active_run:<instance>:<id>`` key found in Redis.
      4. Delete the Redis response list.

    Args:
        db: Object whose ``client`` attribute is awaited to obtain the DB client.
        agent_run_id: Id of the agent run to stop.
        error_message: When provided, the run is marked ``failed`` with this error.
    """
    logger.info(f"Stopping agent run: {agent_run_id}")
    client = await db.client
    final_status = "failed" if error_message else "stopped"

    # 1. Collect buffered responses so they are saved along with the final status.
    response_list_key = f"agent_run:{agent_run_id}:responses"
    all_responses = []
    try:
        all_responses_json = await redis.lrange(response_list_key, 0, -1)
        all_responses = [json.loads(r) for r in all_responses_json]
        logger.info(f"Fetched {len(all_responses)} responses from Redis for DB update on stop/fail: {agent_run_id}")
    except Exception as e:
        logger.error(f"Failed to fetch responses from Redis for {agent_run_id} during stop/fail: {e}")

    # 2. Persist final status and responses.
    update_success = await update_agent_run_status(
        client, agent_run_id, final_status, error=error_message, responses=all_responses
    )
    if not update_success:
        logger.error(f"Failed to update database status for stopped/failed run {agent_run_id}")

    # 3a. Broadcast STOP on the run-wide control channel.
    global_control_channel = f"agent_run:{agent_run_id}:control"
    try:
        await redis.publish(global_control_channel, "STOP")
        logger.debug(f"Published STOP signal to global channel {global_control_channel}")
    except Exception as e:
        logger.error(f"Failed to publish STOP signal to global channel {global_control_channel}: {str(e)}")

    # 3b. Also signal each instance that registered an active_run:<instance>:<run> key.
    try:
        instance_keys = await redis.keys(f"active_run:*:{agent_run_id}")
        logger.debug(f"Found {len(instance_keys)} active instance keys for agent run {agent_run_id}")
        for key in instance_keys:
            parts = key.split(":")
            if len(parts) != 3:
                logger.warning(f"Unexpected key format found: {key}")
                continue
            instance_id_from_key = parts[1]
            instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id_from_key}"
            try:
                await redis.publish(instance_control_channel, "STOP")
                logger.debug(f"Published STOP signal to instance channel {instance_control_channel}")
            except Exception as e:
                logger.warning(f"Failed to publish STOP signal to instance channel {instance_control_channel}: {str(e)}")
    except Exception as e:
        logger.error(f"Failed to find or signal active instances for {agent_run_id}: {str(e)}")

    # 4. Cleanup is independent of instance discovery, so it is always attempted
    #    (the helper swallows its own errors).
    # NOTE(review): the response list is deleted even when update_success is
    # False, which may discard the only copy of the responses — confirm intended.
    await _cleanup_redis_response_list(agent_run_id)

    logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}")


async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
    """
    Check whether the account may start another agent run.

    Counts ``agent_runs`` with status ``'running'`` whose ``started_at`` lies
    within the past 24 hours, across all of the account's threads, and compares
    that count with ``config.MAX_PARALLEL_AGENT_RUNS``.

    The result is cached under ``agent_run_limit:<account_id>``. Cache read and
    write failures are logged and ignored so they can never change the computed
    answer; any other error fails open (``can_start=True``).

    Returns:
        Dict with 'can_start' (bool), 'running_count' (int), 'running_thread_ids' (list)
    """
    cache_key = f"agent_run_limit:{account_id}"
    try:
        # A cache failure must fall through to the DB check, not fail open.
        try:
            cached = await Cache.get(cache_key)
            if cached:
                return cached
        except Exception as cache_error:
            logger.warning(f"Cache read failed for agent run limit {account_id}: {str(cache_error)}")

        twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
        twenty_four_hours_ago_iso = twenty_four_hours_ago.isoformat()
        logger.debug(f"Checking agent run limit for account {account_id} since {twenty_four_hours_ago_iso}")

        threads_result = await client.table('threads').select('thread_id').eq('account_id', account_id).execute()
        if not threads_result.data:
            logger.debug(f"No threads found for account {account_id}")
            return {
                'can_start': True,
                'running_count': 0,
                'running_thread_ids': []
            }

        thread_ids = [thread['thread_id'] for thread in threads_result.data]
        logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")

        running_runs_result = await (
            client.table('agent_runs')
            .select('id', 'thread_id', 'started_at')
            .in_('thread_id', thread_ids)
            .eq('status', 'running')
            .gte('started_at', twenty_four_hours_ago_iso)
            .execute()
        )
        running_runs = running_runs_result.data or []
        running_count = len(running_runs)
        running_thread_ids = [run['thread_id'] for run in running_runs]

        logger.info(f"Account {account_id} has {running_count} running agent runs in the past 24 hours")

        result = {
            'can_start': running_count < config.MAX_PARALLEL_AGENT_RUNS,
            'running_count': running_count,
            'running_thread_ids': running_thread_ids
        }

        # A failed cache write must not discard a computed (possibly blocking) result.
        try:
            await Cache.set(cache_key, result)
        except Exception as cache_error:
            logger.warning(f"Cache write failed for agent run limit {account_id}: {str(cache_error)}")
        return result

    except Exception as e:
        logger.error(f"Error checking agent run limit for account {account_id}: {str(e)}")
        # In case of error, allow the run to proceed but log the error
        return {
            'can_start': True,
            'running_count': 0,
            'running_thread_ids': []
        }


async def check_agent_count_limit(client, account_id: str) -> Dict[str, Any]:
    """Check whether the account may create another custom agent.

    Counts the account's ``agents`` rows whose ``metadata.is_suna_default`` is
    not truthy and compares the count with ``config.AGENT_LIMITS`` for the
    account's subscription tier (falling back to the ``'free'`` limit when the
    tier is unknown or cannot be fetched).

    The result is cached under ``agent_count_limit:<account_id>`` with
    ``ttl=300``; cache failures are logged and ignored. Any other error fails
    open (``can_create=True``) with the free-tier limit.

    Returns:
        Dict with 'can_create' (bool), 'current_count' (int), 'limit' and
        'tier_name'.
    """
    try:
        try:
            result = await Cache.get(f"agent_count_limit:{account_id}")
            if result:
                logger.debug(f"Cache hit for agent count limit: {account_id}")
                return result
        except Exception as cache_error:
            logger.warning(f"Cache read failed for agent count limit {account_id}: {str(cache_error)}")

        agents_result = await client.table('agents').select('agent_id, metadata').eq('account_id', account_id).execute()

        # Suna default agents do not count towards the custom-agent limit.
        non_suna_agents = [
            agent
            for agent in agents_result.data or []
            if not (agent.get('metadata', {}) or {}).get('is_suna_default', False)
        ]
        current_count = len(non_suna_agents)
        logger.debug(f"Account {account_id} has {current_count} custom agents (excluding Suna defaults)")

        try:
            # Imported lazily, as in the original code (possibly to avoid an import cycle).
            from services.billing import get_subscription_tier
            tier_name = await get_subscription_tier(client, account_id)
            logger.debug(f"Account {account_id} subscription tier: {tier_name}")
        except Exception as billing_error:
            logger.warning(f"Could not get subscription tier for {account_id}: {str(billing_error)}, defaulting to free")
            tier_name = 'free'

        agent_limit = config.AGENT_LIMITS.get(tier_name, config.AGENT_LIMITS['free'])
        can_create = current_count < agent_limit

        result = {
            'can_create': can_create,
            'current_count': current_count,
            'limit': agent_limit,
            'tier_name': tier_name
        }

        try:
            await Cache.set(f"agent_count_limit:{account_id}", result, ttl=300)
        except Exception as cache_error:
            logger.warning(f"Cache write failed for agent count limit {account_id}: {str(cache_error)}")

        logger.info(f"Account {account_id} has {current_count}/{agent_limit} agents (tier: {tier_name}) - can_create: {can_create}")
        return result

    except Exception as e:
        logger.error(f"Error checking agent count limit for account {account_id}: {str(e)}", exc_info=True)
        return {
            'can_create': True,
            'current_count': 0,
            'limit': config.AGENT_LIMITS['free'],
            'tier_name': 'free'
        }