import json
from typing import Optional, List, Dict, Any
from datetime import datetime, timezone, timedelta

from utils.logger import logger
from services import redis
from run_agent_background import update_agent_run_status

# Agent run limits: maximum concurrent 'running' agent runs per account.
MAX_PARALLEL_AGENT_RUNS = 3


async def _cleanup_redis_response_list(agent_run_id: str) -> None:
    """Best-effort delete of the Redis list buffering this run's responses.

    Failures are logged as warnings only: a stale response list is harmless
    and must not abort the caller's stop/cleanup flow.
    """
    try:
        response_list_key = f"agent_run:{agent_run_id}:responses"
        await redis.delete(response_list_key)
        logger.debug(f"Cleaned up Redis response list for agent run {agent_run_id}")
    except Exception as e:
        logger.warning(f"Failed to clean up Redis response list for {agent_run_id}: {str(e)}")


async def check_for_active_project_agent_run(client, project_id: str):
    """Return the id of a 'running' agent run in any thread of the project, or None.

    Args:
        client: Async Supabase-style DB client.
        project_id: Project whose threads are scanned.
    """
    project_threads = await client.table('threads').select('thread_id').eq('project_id', project_id).execute()
    project_thread_ids = [t['thread_id'] for t in project_threads.data]

    if project_thread_ids:
        active_runs = await client.table('agent_runs').select('id').in_('thread_id', project_thread_ids).eq('status', 'running').execute()
        # `data and len(data) > 0` was redundant — truthiness covers both.
        if active_runs.data:
            return active_runs.data[0]['id']
    return None


async def stop_agent_run(db, agent_run_id: str, error_message: Optional[str] = None) -> None:
    """Stop (or fail) an agent run: persist the final state and signal all workers.

    Marks the run 'failed' when error_message is given, otherwise 'stopped';
    flushes the buffered Redis responses into the DB row, publishes STOP on
    the global and per-instance control channels, then cleans up the response
    list. Each step is best-effort so a partial failure still lets the rest
    of the stop process proceed.

    Args:
        db: Object exposing an awaitable ``client`` attribute (DB connection).
        agent_run_id: The run to stop.
        error_message: If set, the run is recorded as 'failed' with this error.
    """
    logger.info(f"Stopping agent run: {agent_run_id}")
    client = await db.client
    final_status = "failed" if error_message else "stopped"

    # Fetch the accumulated responses so they can be persisted with the final status.
    response_list_key = f"agent_run:{agent_run_id}:responses"
    all_responses = []
    try:
        all_responses_json = await redis.lrange(response_list_key, 0, -1)
        all_responses = [json.loads(r) for r in all_responses_json]
        logger.info(f"Fetched {len(all_responses)} responses from Redis for DB update on stop/fail: {agent_run_id}")
    except Exception as e:
        logger.error(f"Failed to fetch responses from Redis for {agent_run_id} during stop/fail: {e}")
        # Proceed with the status update even without responses.

    update_success = await update_agent_run_status(
        client, agent_run_id, final_status, error=error_message, responses=all_responses
    )
    if not update_success:
        logger.error(f"Failed to update database status for stopped/failed run {agent_run_id}")

    # Broadcast STOP on the run-wide control channel.
    global_control_channel = f"agent_run:{agent_run_id}:control"
    try:
        await redis.publish(global_control_channel, "STOP")
        logger.debug(f"Published STOP signal to global channel {global_control_channel}")
    except Exception as e:
        logger.error(f"Failed to publish STOP signal to global channel {global_control_channel}: {str(e)}")

    # Signal every instance currently executing this run on its dedicated channel.
    try:
        instance_keys = await redis.keys(f"active_run:*:{agent_run_id}")
        logger.debug(f"Found {len(instance_keys)} active instance keys for agent run {agent_run_id}")
        for key in instance_keys:
            # NOTE(review): assumes the Redis client returns str keys;
            # a bytes result would need .decode() before split — confirm client config.
            parts = key.split(":")
            if len(parts) == 3:
                # Key shape: active_run:{instance_id}:{agent_run_id}
                instance_id_from_key = parts[1]
                instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id_from_key}"
                try:
                    await redis.publish(instance_control_channel, "STOP")
                    logger.debug(f"Published STOP signal to instance channel {instance_control_channel}")
                except Exception as e:
                    logger.warning(f"Failed to publish STOP signal to instance channel {instance_control_channel}: {str(e)}")
            else:
                logger.warning(f"Unexpected key format found: {key}")
        await _cleanup_redis_response_list(agent_run_id)
    except Exception as e:
        logger.error(f"Failed to find or signal active instances for {agent_run_id}: {str(e)}")

    logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}")


async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
    """
    Check if the account has reached the limit of MAX_PARALLEL_AGENT_RUNS
    parallel agent runs within the past 24 hours.

    Fails open: on any error the run is allowed so a transient DB problem
    does not block users.

    Returns:
        Dict with 'can_start' (bool), 'running_count' (int), 'running_thread_ids' (list)
    """
    try:
        # Calculate 24 hours ago (timezone-aware UTC, ISO-formatted for the query).
        twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
        twenty_four_hours_ago_iso = twenty_four_hours_ago.isoformat()

        logger.debug(f"Checking agent run limit for account {account_id} since {twenty_four_hours_ago_iso}")

        # Get all threads for this account
        threads_result = await client.table('threads').select('thread_id').eq('account_id', account_id).execute()

        if not threads_result.data:
            # No threads means nothing can be running.
            logger.debug(f"No threads found for account {account_id}")
            return {
                'can_start': True,
                'running_count': 0,
                'running_thread_ids': []
            }

        thread_ids = [thread['thread_id'] for thread in threads_result.data]
        logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")

        # Query for running agent runs within the past 24 hours for these threads
        running_runs_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()

        running_runs = running_runs_result.data or []
        running_count = len(running_runs)
        running_thread_ids = [run['thread_id'] for run in running_runs]

        logger.info(f"Account {account_id} has {running_count} running agent runs in the past 24 hours")

        return {
            'can_start': running_count < MAX_PARALLEL_AGENT_RUNS,
            'running_count': running_count,
            'running_thread_ids': running_thread_ids
        }

    except Exception as e:
        logger.error(f"Error checking agent run limit for account {account_id}: {str(e)}")
        # In case of error, allow the run to proceed but log the error
        return {
            'can_start': True,
            'running_count': 0,
            'running_thread_ids': []
        }