From 9b3561213d098a0a85296803f4ecad117b1ec8ee Mon Sep 17 00:00:00 2001 From: sharath <29162020+tnfssc@users.noreply.github.com> Date: Wed, 14 May 2025 12:48:02 +0000 Subject: [PATCH] feat(dramatiq): added distributed worker support using dramatiq and rabbitmq --- backend/agent/api.py | 344 ++++++++++++++++---------------- backend/docker-compose.yml | 7 +- backend/run_agent_background.py | 306 ++++++++++++++++++++++++++++ 3 files changed, 481 insertions(+), 176 deletions(-) create mode 100644 backend/run_agent_background.py diff --git a/backend/agent/api.py b/backend/agent/api.py index 618b2b0d..feb42c01 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -21,6 +21,7 @@ from services.billing import check_billing_status from utils.config import config from sandbox.sandbox import create_sandbox, get_or_start_sandbox from services.llm import make_llm_api_call +from run_agent_background import run_agent_background # Initialize shared resources router = APIRouter() @@ -428,19 +429,15 @@ async def start_agent( logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") # Run the agent in the background - task = asyncio.create_task( - run_agent_background( - agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, - project_id=project_id, sandbox=sandbox, - model_name=model_name, # Already resolved above - enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, - stream=body.stream, enable_context_manager=body.enable_context_manager - ) + run_agent_background.send( + agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, + project_id=project_id, + model_name=model_name, # Already resolved above + enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, + stream=body.stream, enable_context_manager=body.enable_context_manager ) # Set a callback to clean up Redis instance key when task is done - task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id))) - return {"agent_run_id": agent_run_id, "status": "running"} @router.post("/agent-run/{agent_run_id}/stop") @@ -661,186 +658,186 @@ async def stream_agent_run( "Access-Control-Allow-Origin": "*" }) -async def run_agent_background( - agent_run_id: str, - thread_id: str, - instance_id: str, # Use the global instance ID passed during initialization - project_id: str, - sandbox, - model_name: str, - enable_thinking: Optional[bool], - reasoning_effort: Optional[str], - stream: bool, - enable_context_manager: bool -): - """Run the agent in the background using Redis for state.""" - logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") - logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") +# @dramatiq.actor +# async def run_agent_background( +# agent_run_id: str, +# thread_id: str, +# instance_id: str, # Use the global instance ID passed during initialization +# project_id: str, +# model_name: str, +# enable_thinking: Optional[bool], +# reasoning_effort: Optional[str], +# stream: bool, +# enable_context_manager: bool +# ): +# """Run the agent in the background using Redis for state.""" +# logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") +# logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") - client = await db.client - start_time = datetime.now(timezone.utc) - total_responses = 0 - pubsub = None - stop_checker = None - stop_signal_received = False +# client = await db.client +# start_time = datetime.now(timezone.utc) +# total_responses = 0 +# pubsub = None +# stop_checker = None +# stop_signal_received = False - # Define Redis keys and channels - response_list_key = f"agent_run:{agent_run_id}:responses" - response_channel = f"agent_run:{agent_run_id}:new_response" - instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" - global_control_channel = f"agent_run:{agent_run_id}:control" - instance_active_key = f"active_run:{instance_id}:{agent_run_id}" +# # Define Redis keys and channels +# response_list_key = f"agent_run:{agent_run_id}:responses" +# response_channel = f"agent_run:{agent_run_id}:new_response" +# instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" +# global_control_channel = f"agent_run:{agent_run_id}:control" +# instance_active_key = f"active_run:{instance_id}:{agent_run_id}" - async def check_for_stop_signal(): - nonlocal stop_signal_received - if not pubsub: return - try: - while not stop_signal_received: - message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) - if message and message.get("type") == "message": - data = message.get("data") - if isinstance(data, bytes): data = data.decode('utf-8') - if data == "STOP": - logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") - stop_signal_received = True - break - # Periodically refresh the active run key TTL - if total_responses % 50 == 0: # Refresh every 50 responses or so - try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) - except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") - await asyncio.sleep(0.1) # Short sleep to prevent tight loop - except asyncio.CancelledError: - logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") - except Exception as e: - logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) - stop_signal_received = True # Stop the run if the checker fails +# async def check_for_stop_signal(): +# nonlocal stop_signal_received +# if not pubsub: return +# try: +# while not stop_signal_received: +# message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) +# if message and message.get("type") == "message": +# data = message.get("data") +# if isinstance(data, bytes): data = data.decode('utf-8') +# if data == "STOP": +# logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") +# stop_signal_received = True +# break +# # Periodically refresh the active run key TTL +# if total_responses % 50 == 0: # Refresh every 50 responses or so +# try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) +# except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") +# await asyncio.sleep(0.1) # Short sleep to prevent tight loop +# except asyncio.CancelledError: +# logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") +# except Exception as e: +# logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) +# stop_signal_received = True # Stop the run if the checker fails - try: - # Setup Pub/Sub listener for control signals - pubsub = await redis.create_pubsub() - await pubsub.subscribe(instance_control_channel, global_control_channel) - logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") - stop_checker = asyncio.create_task(check_for_stop_signal()) +# try: +# # Setup Pub/Sub listener for control signals +# pubsub = await redis.create_pubsub() +# await pubsub.subscribe(instance_control_channel, global_control_channel) +# logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") +# stop_checker = asyncio.create_task(check_for_stop_signal()) - # Ensure active run key exists and has TTL - await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) +# # Ensure active run key exists and has TTL +# await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) - # Initialize agent generator - agent_gen = run_agent( - thread_id=thread_id, project_id=project_id, stream=stream, - thread_manager=thread_manager, model_name=model_name, - enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, - enable_context_manager=enable_context_manager - ) +# # Initialize agent generator +# agent_gen = run_agent( +# thread_id=thread_id, project_id=project_id, stream=stream, +# thread_manager=thread_manager, model_name=model_name, +# enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, +# enable_context_manager=enable_context_manager +# ) - final_status = "running" - error_message = None +# final_status = "running" +# error_message = None - async for response in agent_gen: - if stop_signal_received: - logger.info(f"Agent run {agent_run_id} stopped by signal.") - final_status = "stopped" - break +# async for response in agent_gen: +# if stop_signal_received: +# logger.info(f"Agent run {agent_run_id} stopped by signal.") +# final_status = "stopped" +# break - # Store response in Redis list and publish notification - response_json = json.dumps(response) - await redis.rpush(response_list_key, response_json) - await redis.publish(response_channel, "new") - total_responses += 1 +# # Store response in Redis list and publish notification +# response_json = json.dumps(response) +# await redis.rpush(response_list_key, response_json) +# await redis.publish(response_channel, "new") +# total_responses += 1 - # Check for agent-signaled completion or error - if response.get('type') == 'status': - status_val = response.get('status') - if status_val in ['completed', 'failed', 'stopped']: - logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") - final_status = status_val - if status_val == 'failed' or status_val == 'stopped': - error_message = response.get('message', f"Run ended with status: {status_val}") - break +# # Check for agent-signaled completion or error +# if response.get('type') == 'status': +# status_val = response.get('status') +# if status_val in ['completed', 'failed', 'stopped']: +# logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") +# final_status = status_val +# if status_val == 'failed' or status_val == 'stopped': +# error_message = response.get('message', f"Run ended with status: {status_val}") +# break - # If loop finished without explicit completion/error/stop signal, mark as completed - if final_status == "running": - final_status = "completed" - duration = (datetime.now(timezone.utc) - start_time).total_seconds() - logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") - completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} - await redis.rpush(response_list_key, json.dumps(completion_message)) - await redis.publish(response_channel, "new") # Notify about the completion message +# # If loop finished without explicit completion/error/stop signal, mark as completed +# if final_status == "running": +# final_status = "completed" +# duration = (datetime.now(timezone.utc) - start_time).total_seconds() +# logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") +# completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} +# await redis.rpush(response_list_key, json.dumps(completion_message)) +# await redis.publish(response_channel, "new") # Notify about the completion message - # Fetch final responses from Redis for DB update - all_responses_json = await redis.lrange(response_list_key, 0, -1) - all_responses = [json.loads(r) for r in all_responses_json] +# # Fetch final responses from Redis for DB update +# all_responses_json = await redis.lrange(response_list_key, 0, -1) +# all_responses = [json.loads(r) for r in all_responses_json] - # Update DB status - await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) +# # Update DB status +# await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) - # Publish final control signal (END_STREAM or ERROR) - control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" - try: - await redis.publish(global_control_channel, control_signal) - # No need to publish to instance channel as the run is ending on this instance - logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") - except Exception as e: - logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") +# # Publish final control signal (END_STREAM or ERROR) +# control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" +# try: +# await redis.publish(global_control_channel, control_signal) +# # No need to publish to instance channel as the run is ending on this instance +# logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") +# except Exception as e: +# logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") - except Exception as e: - error_message = str(e) - traceback_str = traceback.format_exc() - duration = (datetime.now(timezone.utc) - start_time).total_seconds() - logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") - final_status = "failed" +# except Exception as e: +# error_message = str(e) +# traceback_str = traceback.format_exc() +# duration = (datetime.now(timezone.utc) - start_time).total_seconds() +# logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") +# final_status = "failed" - # Push error message to Redis list - error_response = {"type": "status", "status": "error", "message": error_message} - try: - await redis.rpush(response_list_key, json.dumps(error_response)) - await redis.publish(response_channel, "new") - except Exception as redis_err: - logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") +# # Push error message to Redis list +# error_response = {"type": "status", "status": "error", "message": error_message} +# try: +# await redis.rpush(response_list_key, json.dumps(error_response)) +# await redis.publish(response_channel, "new") +# except Exception as redis_err: +# logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") - # Fetch final responses (including the error) - all_responses = [] - try: - all_responses_json = await redis.lrange(response_list_key, 0, -1) - all_responses = [json.loads(r) for r in all_responses_json] - except Exception as fetch_err: - logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") - all_responses = [error_response] # Use the error message we tried to push +# # Fetch final responses (including the error) +# all_responses = [] +# try: +# all_responses_json = await redis.lrange(response_list_key, 0, -1) +# all_responses = [json.loads(r) for r in all_responses_json] +# except Exception as fetch_err: +# logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") +# all_responses = [error_response] # Use the error message we tried to push - # Update DB status - await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) +# # Update DB status +# await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) - # Publish ERROR signal - try: - await redis.publish(global_control_channel, "ERROR") - logger.debug(f"Published ERROR signal to {global_control_channel}") - except Exception as e: - logger.warning(f"Failed to publish ERROR signal: {str(e)}") +# # Publish ERROR signal +# try: +# await redis.publish(global_control_channel, "ERROR") +# logger.debug(f"Published ERROR signal to {global_control_channel}") +# except Exception as e: +# logger.warning(f"Failed to publish ERROR signal: {str(e)}") - finally: - # Cleanup stop checker task - if stop_checker and not stop_checker.done(): - stop_checker.cancel() - try: await stop_checker - except asyncio.CancelledError: pass - except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") +# finally: +# # Cleanup stop checker task +# if stop_checker and not stop_checker.done(): +# stop_checker.cancel() +# try: await stop_checker +# except asyncio.CancelledError: pass +# except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") - # Close pubsub connection - if pubsub: - try: - await pubsub.unsubscribe() - await pubsub.close() - logger.debug(f"Closed pubsub connection for {agent_run_id}") - except Exception as e: - logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") +# # Close pubsub connection +# if pubsub: +# try: +# await pubsub.unsubscribe() +# await pubsub.close() +# logger.debug(f"Closed pubsub connection for {agent_run_id}") +# except Exception as e: +# logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") - # Set TTL on the response list in Redis - await _cleanup_redis_response_list(agent_run_id) +# # Set TTL on the response list in Redis +# await _cleanup_redis_response_list(agent_run_id) - # Remove the instance-specific active run key - await _cleanup_redis_instance_key(agent_run_id) +# # Remove the instance-specific active run key +# await _cleanup_redis_instance_key(agent_run_id) - logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") +# logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") async def generate_and_update_project_name(project_id: str, prompt: str): """Generates a project name using an LLM and updates the database.""" @@ -1030,16 +1027,13 @@ async def initiate_agent_with_files( logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") # Run agent in background - task = asyncio.create_task( - run_agent_background( - agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, - project_id=project_id, sandbox=sandbox, - model_name=model_name, # Already resolved above - enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, - stream=stream, enable_context_manager=enable_context_manager - ) + run_agent_background.send( + agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, + project_id=project_id, + model_name=model_name, # Already resolved above + enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, + stream=stream, enable_context_manager=enable_context_manager ) - task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id))) return {"thread_id": thread_id, "agent_run_id": agent_run_id} diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml index 4220dc57..ff409227 100644 --- a/backend/docker-compose.yml +++ b/backend/docker-compose.yml @@ -73,9 +73,14 @@ services: cpus: '1' memory: 8G + rabbitmq: + image: rabbitmq + ports: + - "127.0.0.1:5672:5672" + networks: app-network: driver: bridge volumes: - redis_data: \ No newline at end of file + redis_data: diff --git a/backend/run_agent_background.py b/backend/run_agent_background.py new file mode 100644 index 00000000..0c774b2b --- /dev/null +++ b/backend/run_agent_background.py @@ -0,0 +1,306 @@ +import asyncio +import json +import traceback +from datetime import datetime, timezone +from typing import Optional +from services import redis +from agent.run import run_agent +from utils.logger import logger +import dramatiq +import uuid +from agentpress.thread_manager import ThreadManager +from services.supabase import DBConnection +from services import redis +from dramatiq.brokers.rabbitmq import RabbitmqBroker + +rabbitmq_broker = RabbitmqBroker(host="localhost", port="5672", middleware=[dramatiq.middleware.AsyncIO()]) +dramatiq.set_broker(rabbitmq_broker) + +_initialized = False +db = DBConnection() +thread_manager = None +instance_id = "single" + +async def initialize(): + """Initialize the agent API with resources from the main API.""" + global thread_manager, db, instance_id, _initialized + if _initialized: + return + + # Use provided instance_id or generate a new one + if not instance_id: + # Generate instance ID + instance_id = str(uuid.uuid4())[:8] + await redis.initialize_async() + await db.initialize() + thread_manager = ThreadManager() + + _initialized = True + logger.info(f"Initialized agent API with instance ID: {instance_id}") + + +@dramatiq.actor +async def run_agent_background( + agent_run_id: str, + thread_id: str, + instance_id: str, # Use the global instance ID passed during initialization + project_id: str, + model_name: str, + enable_thinking: Optional[bool], + reasoning_effort: Optional[str], + stream: bool, + enable_context_manager: bool +): + """Run the agent in the background using Redis for state.""" + await initialize() + + logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") + logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") + + client = await db.client + start_time = datetime.now(timezone.utc) + total_responses = 0 + pubsub = None + stop_checker = None + stop_signal_received = False + + # Define Redis keys and channels + response_list_key = f"agent_run:{agent_run_id}:responses" + response_channel = f"agent_run:{agent_run_id}:new_response" + instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" + global_control_channel = f"agent_run:{agent_run_id}:control" + instance_active_key = f"active_run:{instance_id}:{agent_run_id}" + + async def check_for_stop_signal(): + nonlocal stop_signal_received + if not pubsub: return + try: + while not stop_signal_received: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) + if message and message.get("type") == "message": + data = message.get("data") + if isinstance(data, bytes): data = data.decode('utf-8') + if data == "STOP": + logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") + stop_signal_received = True + break + # Periodically refresh the active run key TTL + if total_responses % 50 == 0: # Refresh every 50 responses or so + try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) + except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") + await asyncio.sleep(0.1) # Short sleep to prevent tight loop + except asyncio.CancelledError: + logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") + except Exception as e: + logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) + stop_signal_received = True # Stop the run if the checker fails + + try: + # Setup Pub/Sub listener for control signals + pubsub = await redis.create_pubsub() + await pubsub.subscribe(instance_control_channel, global_control_channel) + logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") + stop_checker = asyncio.create_task(check_for_stop_signal()) + + # Ensure active run key exists and has TTL + await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) + + # Initialize agent generator + agent_gen = run_agent( + thread_id=thread_id, project_id=project_id, stream=stream, + thread_manager=thread_manager, model_name=model_name, + enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, + enable_context_manager=enable_context_manager + ) + + final_status = "running" + error_message = None + + async for response in agent_gen: + if stop_signal_received: + logger.info(f"Agent run {agent_run_id} stopped by signal.") + final_status = "stopped" + break + + # Store response in Redis list and publish notification + response_json = json.dumps(response) + await redis.rpush(response_list_key, response_json) + await redis.publish(response_channel, "new") + total_responses += 1 + + # Check for agent-signaled completion or error + if response.get('type') == 'status': + status_val = response.get('status') + if status_val in ['completed', 'failed', 'stopped']: + logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") + final_status = status_val + if status_val == 'failed' or status_val == 'stopped': + error_message = response.get('message', f"Run ended with status: {status_val}") + break + + # If loop finished without explicit completion/error/stop signal, mark as completed + if final_status == "running": + final_status = "completed" + duration = (datetime.now(timezone.utc) - start_time).total_seconds() + logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") + completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} + await redis.rpush(response_list_key, json.dumps(completion_message)) + await redis.publish(response_channel, "new") # Notify about the completion message + + # Fetch final responses from Redis for DB update + all_responses_json = await redis.lrange(response_list_key, 0, -1) + all_responses = [json.loads(r) for r in all_responses_json] + + # Update DB status + await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) + + # Publish final control signal (END_STREAM or ERROR) + control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" + try: + await redis.publish(global_control_channel, control_signal) + # No need to publish to instance channel as the run is ending on this instance + logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") + except Exception as e: + logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") + + except Exception as e: + error_message = str(e) + traceback_str = traceback.format_exc() + duration = (datetime.now(timezone.utc) - start_time).total_seconds() + logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") + final_status = "failed" + + # Push error message to Redis list + error_response = {"type": "status", "status": "error", "message": error_message} + try: + await redis.rpush(response_list_key, json.dumps(error_response)) + await redis.publish(response_channel, "new") + except Exception as redis_err: + logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") + + # Fetch final responses (including the error) + all_responses = [] + try: + all_responses_json = await redis.lrange(response_list_key, 0, -1) + all_responses = [json.loads(r) for r in all_responses_json] + except Exception as fetch_err: + logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") + all_responses = [error_response] # Use the error message we tried to push + + # Update DB status + await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) + + # Publish ERROR signal + try: + await redis.publish(global_control_channel, "ERROR") + logger.debug(f"Published ERROR signal to {global_control_channel}") + except Exception as e: + logger.warning(f"Failed to publish ERROR signal: {str(e)}") + + finally: + # Cleanup stop checker task + if stop_checker and not stop_checker.done(): + stop_checker.cancel() + try: await stop_checker + except asyncio.CancelledError: pass + except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") + + # Close pubsub connection + if pubsub: + try: + await pubsub.unsubscribe() + await pubsub.close() + logger.debug(f"Closed pubsub connection for {agent_run_id}") + except Exception as e: + logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") + + # Set TTL on the response list in Redis + await _cleanup_redis_response_list(agent_run_id) + + # Remove the instance-specific active run key + await _cleanup_redis_instance_key(agent_run_id) + + logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") + + +async def _cleanup_redis_instance_key(agent_run_id: str): + """Clean up the instance-specific Redis key for an agent run.""" + if not instance_id: + logger.warning("Instance ID not set, cannot clean up instance key.") + return + key = f"active_run:{instance_id}:{agent_run_id}" + logger.debug(f"Cleaning up Redis instance key: {key}") + try: + await redis.delete(key) + logger.debug(f"Successfully cleaned up Redis key: {key}") + except Exception as e: + logger.warning(f"Failed to clean up Redis key {key}: {str(e)}") + +# TTL for Redis response lists (24 hours) +REDIS_RESPONSE_LIST_TTL = 3600 * 24 + +async def _cleanup_redis_response_list(agent_run_id: str): + """Set TTL on the Redis response list.""" + response_list_key = f"agent_run:{agent_run_id}:responses" + try: + await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL) + logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}") + except Exception as e: + logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}") + +async def update_agent_run_status( + client, + agent_run_id: str, + status: str, + error: Optional[str] = None, + responses: Optional[list[any]] = None # Expects parsed list of dicts +) -> bool: + """ + Centralized function to update agent run status. + Returns True if update was successful. + """ + try: + update_data = { + "status": status, + "completed_at": datetime.now(timezone.utc).isoformat() + } + + if error: + update_data["error"] = error + + if responses: + # Ensure responses are stored correctly as JSONB + update_data["responses"] = responses + + # Retry up to 3 times + for retry in range(3): + try: + update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute() + + if hasattr(update_result, 'data') and update_result.data: + logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})") + + # Verify the update + verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute() + if verify_result.data: + actual_status = verify_result.data[0].get('status') + completed_at = verify_result.data[0].get('completed_at') + logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}") + return True + else: + logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}") + if retry == 2: # Last retry + logger.error(f"Failed to update agent run status after all retries: {agent_run_id}") + return False + except Exception as db_error: + logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}") + if retry < 2: # Not the last retry yet + await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff + else: + logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True) + return False + except Exception as e: + logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True) + return False + + return False