feat(dramatiq): added distributed worker support using dramatiq and rabbitmq

This commit is contained in:
sharath 2025-05-14 12:48:02 +00:00
parent fe05ccfb4f
commit 9b3561213d
No known key found for this signature in database
3 changed files with 481 additions and 176 deletions

View File

@ -21,6 +21,7 @@ from services.billing import check_billing_status
from utils.config import config from utils.config import config
from sandbox.sandbox import create_sandbox, get_or_start_sandbox from sandbox.sandbox import create_sandbox, get_or_start_sandbox
from services.llm import make_llm_api_call from services.llm import make_llm_api_call
from run_agent_background import run_agent_background
# Initialize shared resources # Initialize shared resources
router = APIRouter() router = APIRouter()
@ -428,19 +429,15 @@ async def start_agent(
logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
# Run the agent in the background # Run the agent in the background
task = asyncio.create_task( run_agent_background.send(
run_agent_background(
agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
project_id=project_id, sandbox=sandbox, project_id=project_id,
model_name=model_name, # Already resolved above model_name=model_name, # Already resolved above
enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort,
stream=body.stream, enable_context_manager=body.enable_context_manager stream=body.stream, enable_context_manager=body.enable_context_manager
) )
)
# Set a callback to clean up Redis instance key when task is done # Set a callback to clean up Redis instance key when task is done
task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id)))
return {"agent_run_id": agent_run_id, "status": "running"} return {"agent_run_id": agent_run_id, "status": "running"}
@router.post("/agent-run/{agent_run_id}/stop") @router.post("/agent-run/{agent_run_id}/stop")
@ -661,186 +658,186 @@ async def stream_agent_run(
"Access-Control-Allow-Origin": "*" "Access-Control-Allow-Origin": "*"
}) })
async def run_agent_background( # @dramatiq.actor
agent_run_id: str, # async def run_agent_background(
thread_id: str, # agent_run_id: str,
instance_id: str, # Use the global instance ID passed during initialization # thread_id: str,
project_id: str, # instance_id: str, # Use the global instance ID passed during initialization
sandbox, # project_id: str,
model_name: str, # model_name: str,
enable_thinking: Optional[bool], # enable_thinking: Optional[bool],
reasoning_effort: Optional[str], # reasoning_effort: Optional[str],
stream: bool, # stream: bool,
enable_context_manager: bool # enable_context_manager: bool
): # ):
"""Run the agent in the background using Redis for state.""" # """Run the agent in the background using Redis for state."""
logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") # logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") # logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
client = await db.client # client = await db.client
start_time = datetime.now(timezone.utc) # start_time = datetime.now(timezone.utc)
total_responses = 0 # total_responses = 0
pubsub = None # pubsub = None
stop_checker = None # stop_checker = None
stop_signal_received = False # stop_signal_received = False
# Define Redis keys and channels # # Define Redis keys and channels
response_list_key = f"agent_run:{agent_run_id}:responses" # response_list_key = f"agent_run:{agent_run_id}:responses"
response_channel = f"agent_run:{agent_run_id}:new_response" # response_channel = f"agent_run:{agent_run_id}:new_response"
instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" # instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
global_control_channel = f"agent_run:{agent_run_id}:control" # global_control_channel = f"agent_run:{agent_run_id}:control"
instance_active_key = f"active_run:{instance_id}:{agent_run_id}" # instance_active_key = f"active_run:{instance_id}:{agent_run_id}"
async def check_for_stop_signal(): # async def check_for_stop_signal():
nonlocal stop_signal_received # nonlocal stop_signal_received
if not pubsub: return # if not pubsub: return
try: # try:
while not stop_signal_received: # while not stop_signal_received:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) # message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
if message and message.get("type") == "message": # if message and message.get("type") == "message":
data = message.get("data") # data = message.get("data")
if isinstance(data, bytes): data = data.decode('utf-8') # if isinstance(data, bytes): data = data.decode('utf-8')
if data == "STOP": # if data == "STOP":
logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") # logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
stop_signal_received = True # stop_signal_received = True
break # break
# Periodically refresh the active run key TTL # # Periodically refresh the active run key TTL
if total_responses % 50 == 0: # Refresh every 50 responses or so # if total_responses % 50 == 0: # Refresh every 50 responses or so
try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) # try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL)
except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") # except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}")
await asyncio.sleep(0.1) # Short sleep to prevent tight loop # await asyncio.sleep(0.1) # Short sleep to prevent tight loop
except asyncio.CancelledError: # except asyncio.CancelledError:
logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") # logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
except Exception as e: # except Exception as e:
logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) # logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
stop_signal_received = True # Stop the run if the checker fails # stop_signal_received = True # Stop the run if the checker fails
try: # try:
# Setup Pub/Sub listener for control signals # # Setup Pub/Sub listener for control signals
pubsub = await redis.create_pubsub() # pubsub = await redis.create_pubsub()
await pubsub.subscribe(instance_control_channel, global_control_channel) # await pubsub.subscribe(instance_control_channel, global_control_channel)
logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") # logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
stop_checker = asyncio.create_task(check_for_stop_signal()) # stop_checker = asyncio.create_task(check_for_stop_signal())
# Ensure active run key exists and has TTL # # Ensure active run key exists and has TTL
await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) # await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)
# Initialize agent generator # # Initialize agent generator
agent_gen = run_agent( # agent_gen = run_agent(
thread_id=thread_id, project_id=project_id, stream=stream, # thread_id=thread_id, project_id=project_id, stream=stream,
thread_manager=thread_manager, model_name=model_name, # thread_manager=thread_manager, model_name=model_name,
enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, # enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
enable_context_manager=enable_context_manager # enable_context_manager=enable_context_manager
) # )
final_status = "running" # final_status = "running"
error_message = None # error_message = None
async for response in agent_gen: # async for response in agent_gen:
if stop_signal_received: # if stop_signal_received:
logger.info(f"Agent run {agent_run_id} stopped by signal.") # logger.info(f"Agent run {agent_run_id} stopped by signal.")
final_status = "stopped" # final_status = "stopped"
break # break
# Store response in Redis list and publish notification # # Store response in Redis list and publish notification
response_json = json.dumps(response) # response_json = json.dumps(response)
await redis.rpush(response_list_key, response_json) # await redis.rpush(response_list_key, response_json)
await redis.publish(response_channel, "new") # await redis.publish(response_channel, "new")
total_responses += 1 # total_responses += 1
# Check for agent-signaled completion or error # # Check for agent-signaled completion or error
if response.get('type') == 'status': # if response.get('type') == 'status':
status_val = response.get('status') # status_val = response.get('status')
if status_val in ['completed', 'failed', 'stopped']: # if status_val in ['completed', 'failed', 'stopped']:
logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") # logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}")
final_status = status_val # final_status = status_val
if status_val == 'failed' or status_val == 'stopped': # if status_val == 'failed' or status_val == 'stopped':
error_message = response.get('message', f"Run ended with status: {status_val}") # error_message = response.get('message', f"Run ended with status: {status_val}")
break # break
# If loop finished without explicit completion/error/stop signal, mark as completed # # If loop finished without explicit completion/error/stop signal, mark as completed
if final_status == "running": # if final_status == "running":
final_status = "completed" # final_status = "completed"
duration = (datetime.now(timezone.utc) - start_time).total_seconds() # duration = (datetime.now(timezone.utc) - start_time).total_seconds()
logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") # logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} # completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
await redis.rpush(response_list_key, json.dumps(completion_message)) # await redis.rpush(response_list_key, json.dumps(completion_message))
await redis.publish(response_channel, "new") # Notify about the completion message # await redis.publish(response_channel, "new") # Notify about the completion message
# Fetch final responses from Redis for DB update # # Fetch final responses from Redis for DB update
all_responses_json = await redis.lrange(response_list_key, 0, -1) # all_responses_json = await redis.lrange(response_list_key, 0, -1)
all_responses = [json.loads(r) for r in all_responses_json] # all_responses = [json.loads(r) for r in all_responses_json]
# Update DB status # # Update DB status
await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) # await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses)
# Publish final control signal (END_STREAM or ERROR) # # Publish final control signal (END_STREAM or ERROR)
control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" # control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP"
try: # try:
await redis.publish(global_control_channel, control_signal) # await redis.publish(global_control_channel, control_signal)
# No need to publish to instance channel as the run is ending on this instance # # No need to publish to instance channel as the run is ending on this instance
logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") # logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}")
except Exception as e: # except Exception as e:
logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") # logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}")
except Exception as e: # except Exception as e:
error_message = str(e) # error_message = str(e)
traceback_str = traceback.format_exc() # traceback_str = traceback.format_exc()
duration = (datetime.now(timezone.utc) - start_time).total_seconds() # duration = (datetime.now(timezone.utc) - start_time).total_seconds()
logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") # logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
final_status = "failed" # final_status = "failed"
# Push error message to Redis list # # Push error message to Redis list
error_response = {"type": "status", "status": "error", "message": error_message} # error_response = {"type": "status", "status": "error", "message": error_message}
try: # try:
await redis.rpush(response_list_key, json.dumps(error_response)) # await redis.rpush(response_list_key, json.dumps(error_response))
await redis.publish(response_channel, "new") # await redis.publish(response_channel, "new")
except Exception as redis_err: # except Exception as redis_err:
logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") # logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}")
# Fetch final responses (including the error) # # Fetch final responses (including the error)
all_responses = [] # all_responses = []
try: # try:
all_responses_json = await redis.lrange(response_list_key, 0, -1) # all_responses_json = await redis.lrange(response_list_key, 0, -1)
all_responses = [json.loads(r) for r in all_responses_json] # all_responses = [json.loads(r) for r in all_responses_json]
except Exception as fetch_err: # except Exception as fetch_err:
logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") # logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}")
all_responses = [error_response] # Use the error message we tried to push # all_responses = [error_response] # Use the error message we tried to push
# Update DB status # # Update DB status
await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) # await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses)
# Publish ERROR signal # # Publish ERROR signal
try: # try:
await redis.publish(global_control_channel, "ERROR") # await redis.publish(global_control_channel, "ERROR")
logger.debug(f"Published ERROR signal to {global_control_channel}") # logger.debug(f"Published ERROR signal to {global_control_channel}")
except Exception as e: # except Exception as e:
logger.warning(f"Failed to publish ERROR signal: {str(e)}") # logger.warning(f"Failed to publish ERROR signal: {str(e)}")
finally: # finally:
# Cleanup stop checker task # # Cleanup stop checker task
if stop_checker and not stop_checker.done(): # if stop_checker and not stop_checker.done():
stop_checker.cancel() # stop_checker.cancel()
try: await stop_checker # try: await stop_checker
except asyncio.CancelledError: pass # except asyncio.CancelledError: pass
except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") # except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}")
# Close pubsub connection # # Close pubsub connection
if pubsub: # if pubsub:
try: # try:
await pubsub.unsubscribe() # await pubsub.unsubscribe()
await pubsub.close() # await pubsub.close()
logger.debug(f"Closed pubsub connection for {agent_run_id}") # logger.debug(f"Closed pubsub connection for {agent_run_id}")
except Exception as e: # except Exception as e:
logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") # logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}")
# Set TTL on the response list in Redis # # Set TTL on the response list in Redis
await _cleanup_redis_response_list(agent_run_id) # await _cleanup_redis_response_list(agent_run_id)
# Remove the instance-specific active run key # # Remove the instance-specific active run key
await _cleanup_redis_instance_key(agent_run_id) # await _cleanup_redis_instance_key(agent_run_id)
logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") # logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}")
async def generate_and_update_project_name(project_id: str, prompt: str): async def generate_and_update_project_name(project_id: str, prompt: str):
"""Generates a project name using an LLM and updates the database.""" """Generates a project name using an LLM and updates the database."""
@ -1030,16 +1027,13 @@ async def initiate_agent_with_files(
logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
# Run agent in background # Run agent in background
task = asyncio.create_task( run_agent_background.send(
run_agent_background(
agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
project_id=project_id, sandbox=sandbox, project_id=project_id,
model_name=model_name, # Already resolved above model_name=model_name, # Already resolved above
enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
stream=stream, enable_context_manager=enable_context_manager stream=stream, enable_context_manager=enable_context_manager
) )
)
task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id)))
return {"thread_id": thread_id, "agent_run_id": agent_run_id} return {"thread_id": thread_id, "agent_run_id": agent_run_id}

View File

@ -73,6 +73,11 @@ services:
cpus: '1' cpus: '1'
memory: 8G memory: 8G
rabbitmq:
image: rabbitmq
ports:
- "127.0.0.1:5672:5672"
networks: networks:
app-network: app-network:
driver: bridge driver: bridge

View File

@ -0,0 +1,306 @@
import asyncio
import json
import traceback
from datetime import datetime, timezone
from typing import Optional
from services import redis
from agent.run import run_agent
from utils.logger import logger
import dramatiq
import uuid
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from services import redis
from dramatiq.brokers.rabbitmq import RabbitmqBroker
rabbitmq_broker = RabbitmqBroker(host="localhost", port="5672", middleware=[dramatiq.middleware.AsyncIO()])
dramatiq.set_broker(rabbitmq_broker)
_initialized = False
db = DBConnection()
thread_manager = None
instance_id = "single"
async def initialize():
"""Initialize the agent API with resources from the main API."""
global thread_manager, db, instance_id, _initialized
if _initialized:
return
# Use provided instance_id or generate a new one
if not instance_id:
# Generate instance ID
instance_id = str(uuid.uuid4())[:8]
await redis.initialize_async()
await db.initialize()
thread_manager = ThreadManager()
_initialized = True
logger.info(f"Initialized agent API with instance ID: {instance_id}")
@dramatiq.actor
async def run_agent_background(
agent_run_id: str,
thread_id: str,
instance_id: str, # Use the global instance ID passed during initialization
project_id: str,
model_name: str,
enable_thinking: Optional[bool],
reasoning_effort: Optional[str],
stream: bool,
enable_context_manager: bool
):
"""Run the agent in the background using Redis for state."""
await initialize()
logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
client = await db.client
start_time = datetime.now(timezone.utc)
total_responses = 0
pubsub = None
stop_checker = None
stop_signal_received = False
# Define Redis keys and channels
response_list_key = f"agent_run:{agent_run_id}:responses"
response_channel = f"agent_run:{agent_run_id}:new_response"
instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
global_control_channel = f"agent_run:{agent_run_id}:control"
instance_active_key = f"active_run:{instance_id}:{agent_run_id}"
async def check_for_stop_signal():
nonlocal stop_signal_received
if not pubsub: return
try:
while not stop_signal_received:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
if message and message.get("type") == "message":
data = message.get("data")
if isinstance(data, bytes): data = data.decode('utf-8')
if data == "STOP":
logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
stop_signal_received = True
break
# Periodically refresh the active run key TTL
if total_responses % 50 == 0: # Refresh every 50 responses or so
try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL)
except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}")
await asyncio.sleep(0.1) # Short sleep to prevent tight loop
except asyncio.CancelledError:
logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
except Exception as e:
logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
stop_signal_received = True # Stop the run if the checker fails
try:
# Setup Pub/Sub listener for control signals
pubsub = await redis.create_pubsub()
await pubsub.subscribe(instance_control_channel, global_control_channel)
logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
stop_checker = asyncio.create_task(check_for_stop_signal())
# Ensure active run key exists and has TTL
await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)
# Initialize agent generator
agent_gen = run_agent(
thread_id=thread_id, project_id=project_id, stream=stream,
thread_manager=thread_manager, model_name=model_name,
enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
enable_context_manager=enable_context_manager
)
final_status = "running"
error_message = None
async for response in agent_gen:
if stop_signal_received:
logger.info(f"Agent run {agent_run_id} stopped by signal.")
final_status = "stopped"
break
# Store response in Redis list and publish notification
response_json = json.dumps(response)
await redis.rpush(response_list_key, response_json)
await redis.publish(response_channel, "new")
total_responses += 1
# Check for agent-signaled completion or error
if response.get('type') == 'status':
status_val = response.get('status')
if status_val in ['completed', 'failed', 'stopped']:
logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}")
final_status = status_val
if status_val == 'failed' or status_val == 'stopped':
error_message = response.get('message', f"Run ended with status: {status_val}")
break
# If loop finished without explicit completion/error/stop signal, mark as completed
if final_status == "running":
final_status = "completed"
duration = (datetime.now(timezone.utc) - start_time).total_seconds()
logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
await redis.rpush(response_list_key, json.dumps(completion_message))
await redis.publish(response_channel, "new") # Notify about the completion message
# Fetch final responses from Redis for DB update
all_responses_json = await redis.lrange(response_list_key, 0, -1)
all_responses = [json.loads(r) for r in all_responses_json]
# Update DB status
await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses)
# Publish final control signal (END_STREAM or ERROR)
control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP"
try:
await redis.publish(global_control_channel, control_signal)
# No need to publish to instance channel as the run is ending on this instance
logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}")
except Exception as e:
logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}")
except Exception as e:
error_message = str(e)
traceback_str = traceback.format_exc()
duration = (datetime.now(timezone.utc) - start_time).total_seconds()
logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
final_status = "failed"
# Push error message to Redis list
error_response = {"type": "status", "status": "error", "message": error_message}
try:
await redis.rpush(response_list_key, json.dumps(error_response))
await redis.publish(response_channel, "new")
except Exception as redis_err:
logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}")
# Fetch final responses (including the error)
all_responses = []
try:
all_responses_json = await redis.lrange(response_list_key, 0, -1)
all_responses = [json.loads(r) for r in all_responses_json]
except Exception as fetch_err:
logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}")
all_responses = [error_response] # Use the error message we tried to push
# Update DB status
await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses)
# Publish ERROR signal
try:
await redis.publish(global_control_channel, "ERROR")
logger.debug(f"Published ERROR signal to {global_control_channel}")
except Exception as e:
logger.warning(f"Failed to publish ERROR signal: {str(e)}")
finally:
# Cleanup stop checker task
if stop_checker and not stop_checker.done():
stop_checker.cancel()
try: await stop_checker
except asyncio.CancelledError: pass
except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}")
# Close pubsub connection
if pubsub:
try:
await pubsub.unsubscribe()
await pubsub.close()
logger.debug(f"Closed pubsub connection for {agent_run_id}")
except Exception as e:
logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}")
# Set TTL on the response list in Redis
await _cleanup_redis_response_list(agent_run_id)
# Remove the instance-specific active run key
await _cleanup_redis_instance_key(agent_run_id)
logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}")
async def _cleanup_redis_instance_key(agent_run_id: str):
"""Clean up the instance-specific Redis key for an agent run."""
if not instance_id:
logger.warning("Instance ID not set, cannot clean up instance key.")
return
key = f"active_run:{instance_id}:{agent_run_id}"
logger.debug(f"Cleaning up Redis instance key: {key}")
try:
await redis.delete(key)
logger.debug(f"Successfully cleaned up Redis key: {key}")
except Exception as e:
logger.warning(f"Failed to clean up Redis key {key}: {str(e)}")
# TTL for Redis response lists (24 hours)
REDIS_RESPONSE_LIST_TTL = 3600 * 24
async def _cleanup_redis_response_list(agent_run_id: str):
"""Set TTL on the Redis response list."""
response_list_key = f"agent_run:{agent_run_id}:responses"
try:
await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL)
logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}")
except Exception as e:
logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}")
async def update_agent_run_status(
client,
agent_run_id: str,
status: str,
error: Optional[str] = None,
responses: Optional[list[any]] = None # Expects parsed list of dicts
) -> bool:
"""
Centralized function to update agent run status.
Returns True if update was successful.
"""
try:
update_data = {
"status": status,
"completed_at": datetime.now(timezone.utc).isoformat()
}
if error:
update_data["error"] = error
if responses:
# Ensure responses are stored correctly as JSONB
update_data["responses"] = responses
# Retry up to 3 times
for retry in range(3):
try:
update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute()
if hasattr(update_result, 'data') and update_result.data:
logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})")
# Verify the update
verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute()
if verify_result.data:
actual_status = verify_result.data[0].get('status')
completed_at = verify_result.data[0].get('completed_at')
logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}")
return True
else:
logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}")
if retry == 2: # Last retry
logger.error(f"Failed to update agent run status after all retries: {agent_run_id}")
return False
except Exception as db_error:
logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}")
if retry < 2: # Not the last retry yet
await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff
else:
logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True)
return False
except Exception as e:
logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True)
return False
return False