From 9b3561213d098a0a85296803f4ecad117b1ec8ee Mon Sep 17 00:00:00 2001 From: sharath <29162020+tnfssc@users.noreply.github.com> Date: Wed, 14 May 2025 12:48:02 +0000 Subject: [PATCH 1/4] feat(dramatiq): added distributed worker support using dramatiq and rabbitmq --- backend/agent/api.py | 344 ++++++++++++++++---------------- backend/docker-compose.yml | 7 +- backend/run_agent_background.py | 306 ++++++++++++++++++++++++++++ 3 files changed, 481 insertions(+), 176 deletions(-) create mode 100644 backend/run_agent_background.py diff --git a/backend/agent/api.py b/backend/agent/api.py index 618b2b0d..feb42c01 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -21,6 +21,7 @@ from services.billing import check_billing_status from utils.config import config from sandbox.sandbox import create_sandbox, get_or_start_sandbox from services.llm import make_llm_api_call +from run_agent_background import run_agent_background # Initialize shared resources router = APIRouter() @@ -428,19 +429,15 @@ async def start_agent( logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") # Run the agent in the background - task = asyncio.create_task( - run_agent_background( - agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, - project_id=project_id, sandbox=sandbox, - model_name=model_name, # Already resolved above - enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, - stream=body.stream, enable_context_manager=body.enable_context_manager - ) + run_agent_background.send( + agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, + project_id=project_id, + model_name=model_name, # Already resolved above + enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, + stream=body.stream, enable_context_manager=body.enable_context_manager ) # Set a callback to clean up Redis instance key when task is done - task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id))) - return {"agent_run_id": agent_run_id, "status": "running"} @router.post("/agent-run/{agent_run_id}/stop") @@ -661,186 +658,186 @@ async def stream_agent_run( "Access-Control-Allow-Origin": "*" }) -async def run_agent_background( - agent_run_id: str, - thread_id: str, - instance_id: str, # Use the global instance ID passed during initialization - project_id: str, - sandbox, - model_name: str, - enable_thinking: Optional[bool], - reasoning_effort: Optional[str], - stream: bool, - enable_context_manager: bool -): - """Run the agent in the background using Redis for state.""" - logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") - logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") +# @dramatiq.actor +# async def run_agent_background( +# agent_run_id: str, +# thread_id: str, +# instance_id: str, # Use the global instance ID passed during initialization +# project_id: str, +# model_name: str, +# enable_thinking: Optional[bool], +# reasoning_effort: Optional[str], +# stream: bool, +# enable_context_manager: bool +# ): +# """Run the agent in the background using Redis for state.""" +# logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") +# logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") - client = await db.client - start_time = datetime.now(timezone.utc) - 
total_responses = 0 - pubsub = None - stop_checker = None - stop_signal_received = False +# client = await db.client +# start_time = datetime.now(timezone.utc) +# total_responses = 0 +# pubsub = None +# stop_checker = None +# stop_signal_received = False - # Define Redis keys and channels - response_list_key = f"agent_run:{agent_run_id}:responses" - response_channel = f"agent_run:{agent_run_id}:new_response" - instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" - global_control_channel = f"agent_run:{agent_run_id}:control" - instance_active_key = f"active_run:{instance_id}:{agent_run_id}" +# # Define Redis keys and channels +# response_list_key = f"agent_run:{agent_run_id}:responses" +# response_channel = f"agent_run:{agent_run_id}:new_response" +# instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" +# global_control_channel = f"agent_run:{agent_run_id}:control" +# instance_active_key = f"active_run:{instance_id}:{agent_run_id}" - async def check_for_stop_signal(): - nonlocal stop_signal_received - if not pubsub: return - try: - while not stop_signal_received: - message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) - if message and message.get("type") == "message": - data = message.get("data") - if isinstance(data, bytes): data = data.decode('utf-8') - if data == "STOP": - logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") - stop_signal_received = True - break - # Periodically refresh the active run key TTL - if total_responses % 50 == 0: # Refresh every 50 responses or so - try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) - except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") - await asyncio.sleep(0.1) # Short sleep to prevent tight loop - except asyncio.CancelledError: - logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") - except Exception as e: - logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) - stop_signal_received = True # Stop the run if the checker fails +# async def check_for_stop_signal(): +# nonlocal stop_signal_received +# if not pubsub: return +# try: +# while not stop_signal_received: +# message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) +# if message and message.get("type") == "message": +# data = message.get("data") +# if isinstance(data, bytes): data = data.decode('utf-8') +# if data == "STOP": +# logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") +# stop_signal_received = True +# break +# # Periodically refresh the active run key TTL +# if total_responses % 50 == 0: # Refresh every 50 responses or so +# try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) +# except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") +# await asyncio.sleep(0.1) # Short sleep to prevent tight loop +# except asyncio.CancelledError: +# logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") +# except Exception as e: +# logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) +# stop_signal_received = True # Stop the run if the checker fails - try: - # Setup Pub/Sub listener for control signals - pubsub = await redis.create_pubsub() - await pubsub.subscribe(instance_control_channel, global_control_channel) - logger.debug(f"Subscribed to control 
channels: {instance_control_channel}, {global_control_channel}") - stop_checker = asyncio.create_task(check_for_stop_signal()) +# try: +# # Setup Pub/Sub listener for control signals +# pubsub = await redis.create_pubsub() +# await pubsub.subscribe(instance_control_channel, global_control_channel) +# logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") +# stop_checker = asyncio.create_task(check_for_stop_signal()) - # Ensure active run key exists and has TTL - await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) +# # Ensure active run key exists and has TTL +# await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) - # Initialize agent generator - agent_gen = run_agent( - thread_id=thread_id, project_id=project_id, stream=stream, - thread_manager=thread_manager, model_name=model_name, - enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, - enable_context_manager=enable_context_manager - ) +# # Initialize agent generator +# agent_gen = run_agent( +# thread_id=thread_id, project_id=project_id, stream=stream, +# thread_manager=thread_manager, model_name=model_name, +# enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, +# enable_context_manager=enable_context_manager +# ) - final_status = "running" - error_message = None +# final_status = "running" +# error_message = None - async for response in agent_gen: - if stop_signal_received: - logger.info(f"Agent run {agent_run_id} stopped by signal.") - final_status = "stopped" - break +# async for response in agent_gen: +# if stop_signal_received: +# logger.info(f"Agent run {agent_run_id} stopped by signal.") +# final_status = "stopped" +# break - # Store response in Redis list and publish notification - response_json = json.dumps(response) - await redis.rpush(response_list_key, response_json) - await redis.publish(response_channel, "new") - total_responses += 1 +# # Store response in Redis list and publish notification +# response_json = json.dumps(response) +# await redis.rpush(response_list_key, response_json) +# await redis.publish(response_channel, "new") +# total_responses += 1 - # Check for agent-signaled completion or error - if response.get('type') == 'status': - status_val = response.get('status') - if status_val in ['completed', 'failed', 'stopped']: - logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") - final_status = status_val - if status_val == 'failed' or status_val == 'stopped': - error_message = response.get('message', f"Run ended with status: {status_val}") - break +# # Check for agent-signaled completion or error +# if response.get('type') == 'status': +# status_val = response.get('status') +# if status_val in ['completed', 'failed', 'stopped']: +# logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") +# final_status = status_val +# if status_val == 'failed' or status_val == 'stopped': +# error_message = response.get('message', f"Run ended with status: {status_val}") +# break - # If loop finished without explicit completion/error/stop signal, mark as completed - if final_status == "running": - final_status = "completed" - duration = (datetime.now(timezone.utc) - start_time).total_seconds() - logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") - completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} - await redis.rpush(response_list_key, 
json.dumps(completion_message)) - await redis.publish(response_channel, "new") # Notify about the completion message +# # If loop finished without explicit completion/error/stop signal, mark as completed +# if final_status == "running": +# final_status = "completed" +# duration = (datetime.now(timezone.utc) - start_time).total_seconds() +# logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") +# completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} +# await redis.rpush(response_list_key, json.dumps(completion_message)) +# await redis.publish(response_channel, "new") # Notify about the completion message - # Fetch final responses from Redis for DB update - all_responses_json = await redis.lrange(response_list_key, 0, -1) - all_responses = [json.loads(r) for r in all_responses_json] +# # Fetch final responses from Redis for DB update +# all_responses_json = await redis.lrange(response_list_key, 0, -1) +# all_responses = [json.loads(r) for r in all_responses_json] - # Update DB status - await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) +# # Update DB status +# await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) - # Publish final control signal (END_STREAM or ERROR) - control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" - try: - await redis.publish(global_control_channel, control_signal) - # No need to publish to instance channel as the run is ending on this instance - logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") - except Exception as e: - logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") +# # Publish final control signal (END_STREAM or ERROR) +# control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" +# try: +# await redis.publish(global_control_channel, control_signal) +# # No need to publish to instance channel as the run is ending on this instance +# logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") +# except Exception as e: +# logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") - except Exception as e: - error_message = str(e) - traceback_str = traceback.format_exc() - duration = (datetime.now(timezone.utc) - start_time).total_seconds() - logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") - final_status = "failed" +# except Exception as e: +# error_message = str(e) +# traceback_str = traceback.format_exc() +# duration = (datetime.now(timezone.utc) - start_time).total_seconds() +# logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") +# final_status = "failed" - # Push error message to Redis list - error_response = {"type": "status", "status": "error", "message": error_message} - try: - await redis.rpush(response_list_key, json.dumps(error_response)) - await redis.publish(response_channel, "new") - except Exception as redis_err: - logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") +# # Push error message to Redis list +# error_response = {"type": "status", "status": "error", 
"message": error_message} +# try: +# await redis.rpush(response_list_key, json.dumps(error_response)) +# await redis.publish(response_channel, "new") +# except Exception as redis_err: +# logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") - # Fetch final responses (including the error) - all_responses = [] - try: - all_responses_json = await redis.lrange(response_list_key, 0, -1) - all_responses = [json.loads(r) for r in all_responses_json] - except Exception as fetch_err: - logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") - all_responses = [error_response] # Use the error message we tried to push +# # Fetch final responses (including the error) +# all_responses = [] +# try: +# all_responses_json = await redis.lrange(response_list_key, 0, -1) +# all_responses = [json.loads(r) for r in all_responses_json] +# except Exception as fetch_err: +# logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") +# all_responses = [error_response] # Use the error message we tried to push - # Update DB status - await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) +# # Update DB status +# await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) - # Publish ERROR signal - try: - await redis.publish(global_control_channel, "ERROR") - logger.debug(f"Published ERROR signal to {global_control_channel}") - except Exception as e: - logger.warning(f"Failed to publish ERROR signal: {str(e)}") +# # Publish ERROR signal +# try: +# await redis.publish(global_control_channel, "ERROR") +# logger.debug(f"Published ERROR signal to {global_control_channel}") +# except Exception as e: +# logger.warning(f"Failed to publish ERROR signal: {str(e)}") - finally: - # Cleanup stop checker task - if stop_checker and not stop_checker.done(): - stop_checker.cancel() - try: await stop_checker - except asyncio.CancelledError: pass - except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") +# finally: +# # Cleanup stop checker task +# if stop_checker and not stop_checker.done(): +# stop_checker.cancel() +# try: await stop_checker +# except asyncio.CancelledError: pass +# except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") - # Close pubsub connection - if pubsub: - try: - await pubsub.unsubscribe() - await pubsub.close() - logger.debug(f"Closed pubsub connection for {agent_run_id}") - except Exception as e: - logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") +# # Close pubsub connection +# if pubsub: +# try: +# await pubsub.unsubscribe() +# await pubsub.close() +# logger.debug(f"Closed pubsub connection for {agent_run_id}") +# except Exception as e: +# logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") - # Set TTL on the response list in Redis - await _cleanup_redis_response_list(agent_run_id) +# # Set TTL on the response list in Redis +# await _cleanup_redis_response_list(agent_run_id) - # Remove the instance-specific active run key - await _cleanup_redis_instance_key(agent_run_id) +# # Remove the instance-specific active run key +# await _cleanup_redis_instance_key(agent_run_id) - logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") +# logger.info(f"Agent run background task fully 
completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") async def generate_and_update_project_name(project_id: str, prompt: str): """Generates a project name using an LLM and updates the database.""" @@ -1030,16 +1027,13 @@ async def initiate_agent_with_files( logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") # Run agent in background - task = asyncio.create_task( - run_agent_background( - agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, - project_id=project_id, sandbox=sandbox, - model_name=model_name, # Already resolved above - enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, - stream=stream, enable_context_manager=enable_context_manager - ) + run_agent_background.send( + agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, + project_id=project_id, + model_name=model_name, # Already resolved above + enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, + stream=stream, enable_context_manager=enable_context_manager ) - task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id))) return {"thread_id": thread_id, "agent_run_id": agent_run_id} diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml index 4220dc57..ff409227 100644 --- a/backend/docker-compose.yml +++ b/backend/docker-compose.yml @@ -73,9 +73,14 @@ services: cpus: '1' memory: 8G + rabbitmq: + image: rabbitmq + ports: + - "127.0.0.1:5672:5672" + networks: app-network: driver: bridge volumes: - redis_data: \ No newline at end of file + redis_data: diff --git a/backend/run_agent_background.py b/backend/run_agent_background.py new file mode 100644 index 00000000..0c774b2b --- /dev/null +++ b/backend/run_agent_background.py @@ -0,0 +1,306 @@ +import asyncio +import json +import traceback +from datetime import datetime, timezone +from typing import Optional +from services import redis +from agent.run import run_agent +from utils.logger import logger +import dramatiq +import uuid +from agentpress.thread_manager import ThreadManager +from services.supabase import DBConnection +from services import redis +from dramatiq.brokers.rabbitmq import RabbitmqBroker + +rabbitmq_broker = RabbitmqBroker(host="localhost", port="5672", middleware=[dramatiq.middleware.AsyncIO()]) +dramatiq.set_broker(rabbitmq_broker) + +_initialized = False +db = DBConnection() +thread_manager = None +instance_id = "single" + +async def initialize(): + """Initialize the agent API with resources from the main API.""" + global thread_manager, db, instance_id, _initialized + if _initialized: + return + + # Use provided instance_id or generate a new one + if not instance_id: + # Generate instance ID + instance_id = str(uuid.uuid4())[:8] + await redis.initialize_async() + await db.initialize() + thread_manager = ThreadManager() + + _initialized = True + logger.info(f"Initialized agent API with instance ID: {instance_id}") + + +@dramatiq.actor +async def run_agent_background( + agent_run_id: str, + thread_id: str, + instance_id: str, # Use the global instance ID passed during initialization + project_id: str, + model_name: str, + enable_thinking: Optional[bool], + reasoning_effort: Optional[str], + stream: bool, + enable_context_manager: bool +): + """Run the agent in the background using Redis for state.""" + await initialize() + + logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") + logger.info(f"🚀 Using model: 
{model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") + + client = await db.client + start_time = datetime.now(timezone.utc) + total_responses = 0 + pubsub = None + stop_checker = None + stop_signal_received = False + + # Define Redis keys and channels + response_list_key = f"agent_run:{agent_run_id}:responses" + response_channel = f"agent_run:{agent_run_id}:new_response" + instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" + global_control_channel = f"agent_run:{agent_run_id}:control" + instance_active_key = f"active_run:{instance_id}:{agent_run_id}" + + async def check_for_stop_signal(): + nonlocal stop_signal_received + if not pubsub: return + try: + while not stop_signal_received: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) + if message and message.get("type") == "message": + data = message.get("data") + if isinstance(data, bytes): data = data.decode('utf-8') + if data == "STOP": + logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") + stop_signal_received = True + break + # Periodically refresh the active run key TTL + if total_responses % 50 == 0: # Refresh every 50 responses or so + try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) + except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") + await asyncio.sleep(0.1) # Short sleep to prevent tight loop + except asyncio.CancelledError: + logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") + except Exception as e: + logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) + stop_signal_received = True # Stop the run if the checker fails + + try: + # Setup Pub/Sub listener for control signals + pubsub = await redis.create_pubsub() + await pubsub.subscribe(instance_control_channel, global_control_channel) + logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") + stop_checker = asyncio.create_task(check_for_stop_signal()) + + # Ensure active run key exists and has TTL + await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) + + # Initialize agent generator + agent_gen = run_agent( + thread_id=thread_id, project_id=project_id, stream=stream, + thread_manager=thread_manager, model_name=model_name, + enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, + enable_context_manager=enable_context_manager + ) + + final_status = "running" + error_message = None + + async for response in agent_gen: + if stop_signal_received: + logger.info(f"Agent run {agent_run_id} stopped by signal.") + final_status = "stopped" + break + + # Store response in Redis list and publish notification + response_json = json.dumps(response) + await redis.rpush(response_list_key, response_json) + await redis.publish(response_channel, "new") + total_responses += 1 + + # Check for agent-signaled completion or error + if response.get('type') == 'status': + status_val = response.get('status') + if status_val in ['completed', 'failed', 'stopped']: + logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") + final_status = status_val + if status_val == 'failed' or status_val == 'stopped': + error_message = response.get('message', f"Run ended with status: {status_val}") + break + + # If loop finished without explicit completion/error/stop signal, mark as completed + if final_status == "running": + final_status = 
"completed" + duration = (datetime.now(timezone.utc) - start_time).total_seconds() + logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") + completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} + await redis.rpush(response_list_key, json.dumps(completion_message)) + await redis.publish(response_channel, "new") # Notify about the completion message + + # Fetch final responses from Redis for DB update + all_responses_json = await redis.lrange(response_list_key, 0, -1) + all_responses = [json.loads(r) for r in all_responses_json] + + # Update DB status + await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) + + # Publish final control signal (END_STREAM or ERROR) + control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" + try: + await redis.publish(global_control_channel, control_signal) + # No need to publish to instance channel as the run is ending on this instance + logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") + except Exception as e: + logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") + + except Exception as e: + error_message = str(e) + traceback_str = traceback.format_exc() + duration = (datetime.now(timezone.utc) - start_time).total_seconds() + logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") + final_status = "failed" + + # Push error message to Redis list + error_response = {"type": "status", "status": "error", "message": error_message} + try: + await redis.rpush(response_list_key, json.dumps(error_response)) + await redis.publish(response_channel, "new") + except Exception as redis_err: + logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") + + # Fetch final responses (including the error) + all_responses = [] + try: + all_responses_json = await redis.lrange(response_list_key, 0, -1) + all_responses = [json.loads(r) for r in all_responses_json] + except Exception as fetch_err: + logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") + all_responses = [error_response] # Use the error message we tried to push + + # Update DB status + await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) + + # Publish ERROR signal + try: + await redis.publish(global_control_channel, "ERROR") + logger.debug(f"Published ERROR signal to {global_control_channel}") + except Exception as e: + logger.warning(f"Failed to publish ERROR signal: {str(e)}") + + finally: + # Cleanup stop checker task + if stop_checker and not stop_checker.done(): + stop_checker.cancel() + try: await stop_checker + except asyncio.CancelledError: pass + except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") + + # Close pubsub connection + if pubsub: + try: + await pubsub.unsubscribe() + await pubsub.close() + logger.debug(f"Closed pubsub connection for {agent_run_id}") + except Exception as e: + logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") + + # Set TTL on the response list in Redis + await _cleanup_redis_response_list(agent_run_id) + + # Remove the instance-specific active run key + await _cleanup_redis_instance_key(agent_run_id) 
+ + logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") + + +async def _cleanup_redis_instance_key(agent_run_id: str): + """Clean up the instance-specific Redis key for an agent run.""" + if not instance_id: + logger.warning("Instance ID not set, cannot clean up instance key.") + return + key = f"active_run:{instance_id}:{agent_run_id}" + logger.debug(f"Cleaning up Redis instance key: {key}") + try: + await redis.delete(key) + logger.debug(f"Successfully cleaned up Redis key: {key}") + except Exception as e: + logger.warning(f"Failed to clean up Redis key {key}: {str(e)}") + +# TTL for Redis response lists (24 hours) +REDIS_RESPONSE_LIST_TTL = 3600 * 24 + +async def _cleanup_redis_response_list(agent_run_id: str): + """Set TTL on the Redis response list.""" + response_list_key = f"agent_run:{agent_run_id}:responses" + try: + await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL) + logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}") + except Exception as e: + logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}") + +async def update_agent_run_status( + client, + agent_run_id: str, + status: str, + error: Optional[str] = None, + responses: Optional[list[any]] = None # Expects parsed list of dicts +) -> bool: + """ + Centralized function to update agent run status. + Returns True if update was successful. + """ + try: + update_data = { + "status": status, + "completed_at": datetime.now(timezone.utc).isoformat() + } + + if error: + update_data["error"] = error + + if responses: + # Ensure responses are stored correctly as JSONB + update_data["responses"] = responses + + # Retry up to 3 times + for retry in range(3): + try: + update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute() + + if hasattr(update_result, 'data') and update_result.data: + logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})") + + # Verify the update + verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute() + if verify_result.data: + actual_status = verify_result.data[0].get('status') + completed_at = verify_result.data[0].get('completed_at') + logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}") + return True + else: + logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}") + if retry == 2: # Last retry + logger.error(f"Failed to update agent run status after all retries: {agent_run_id}") + return False + except Exception as db_error: + logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}") + if retry < 2: # Not the last retry yet + await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff + else: + logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True) + return False + except Exception as e: + logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True) + return False + + return False From 3a4debd407643c78c2c0274831777a1b81f0ef74 Mon Sep 17 00:00:00 2001 From: sharath <29162020+tnfssc@users.noreply.github.com> Date: Thu, 15 May 2025 06:29:27 +0000 Subject: [PATCH 2/4] feat(dramatiq): add workers to docker compose and update docs --- README.md | 20 +++++- backend/.env.example | 3 + 
backend/README.md | 25 ++++---
 backend/docker-compose.yml | 117 ++++++++++++++++++++++++++++++--
 backend/requirements.txt | 3 +-
 backend/run_agent_background.py | 5 +-
 6 files changed, 155 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 6d33345e..8924310e 100644
--- a/README.md
+++ b/README.md
@@ -86,6 +86,7 @@ Suna can be self-hosted on your own infrastructure. Follow these steps to set up
 You'll need the following components:
 - A Supabase project for database and authentication
 - Redis database for caching and session management
+- RabbitMQ message queue for orchestrating worker tasks
 - Daytona sandbox for secure agent execution
 - Python 3.11 for the API backend
 - API keys for LLM providers (Anthropic, OpenRouter)
@@ -99,9 +100,9 @@ You'll need the following components:
    - Save your project's API URL, anon key, and service role key for later use
    - Install the [Supabase CLI](https://supabase.com/docs/guides/cli/getting-started)
 
-2. **Redis**:
+2. **Redis and RabbitMQ**:
    - Go to the `/backend` folder
-   - Run `docker compose up redis`
+   - Run `docker compose up redis rabbitmq`
 
 3. **Daytona**:
    - Create an account on [Daytona](https://app.daytona.io/)
@@ -157,6 +158,9 @@ REDIS_PORT=6379
 REDIS_PASSWORD=your_redis_password
 REDIS_SSL=True # Set to False for local Redis without SSL
 
+RABBITMQ_HOST=your_rabbitmq_host # Set to localhost if running locally
+RABBITMQ_PORT=5672
+
 # Daytona credentials from step 3
 DAYTONA_API_KEY=your_daytona_api_key
 DAYTONA_SERVER_URL="https://app.daytona.io/api"
@@ -230,6 +234,12 @@ npm run dev
 ```bash
 cd backend
 poetry run python3.11 api.py
+```
+
+   In another terminal, start the backend worker:
+```bash
+cd backend
+poetry run python3.11 -m dramatiq run_agent_background
 ```
 
 5-6. **Docker Compose Alternative**:
@@ -237,12 +247,16 @@ poetry run python3.11 api.py
    Before running with Docker Compose, make sure your environment files are properly configured:
    - In `backend/.env`, set all the required environment variables as described above
    - For Redis configuration, use `REDIS_HOST=redis` instead of localhost
+   - For RabbitMQ, use `RABBITMQ_HOST=rabbitmq` instead of localhost
-   - The Docker Compose setup will automatically set these Redis environment variables:
+   - The Docker Compose setup will automatically set these Redis and RabbitMQ environment variables:
      ```
      REDIS_HOST=redis
      REDIS_PORT=6379
      REDIS_PASSWORD=
      REDIS_SSL=False
+
+     RABBITMQ_HOST=rabbitmq
+     RABBITMQ_PORT=5672
      ```
    - In `frontend/.env.local`, make sure to set `NEXT_PUBLIC_BACKEND_URL="http://backend:8000/api"` to use the container name
 
 If you're building the images locally instead of using pre-built ones:
 ```bash
 docker compose up
 ```
 
-The Docker Compose setup includes a Redis service that will be used by the backend automatically.
+The Docker Compose setup includes Redis and RabbitMQ services that will be used by the backend automatically.
 
 7. **Access Suna**:

diff --git a/backend/.env.example b/backend/.env.example
index 2961f80c..4588ab77 100644
--- a/backend/.env.example
+++ b/backend/.env.example
@@ -14,6 +14,9 @@ REDIS_PORT=6379
 REDIS_PASSWORD=
 REDIS_SSL=false
 
+RABBITMQ_HOST=rabbitmq
+RABBITMQ_PORT=5672
+
 # LLM Providers:
 ANTHROPIC_API_KEY=
 OPENAI_API_KEY=
diff --git a/backend/README.md b/backend/README.md
index b7253570..c16a94fb 100644
--- a/backend/README.md
+++ b/backend/README.md
@@ -11,25 +11,25 @@ docker compose down && docker compose up --build
 
 You can run individual services from the docker-compose file.
This is particularly useful during development:
 
-### Running only Redis
+### Running only Redis and RabbitMQ
 ```bash
-docker compose up redis
+docker compose up redis rabbitmq
 ```
 
-### Running only the API
+### Running only the API and Worker
 ```bash
-docker compose up api
+docker compose up api worker
 ```
 
 ## Development Setup
 
-For local development, you might only need to run Redis while working on the API locally. This is useful when:
+For local development, you might only need to run Redis and RabbitMQ while working on the API locally. This is useful when:
 
 - You're making changes to the API code and want to test them directly
 - You want to avoid rebuilding the API container on every change
 - You're running the API service directly on your machine
 
-To run just Redis for development:
-```bash
-docker compose up redis
+To run just Redis and RabbitMQ for development:
+```bash
+docker compose up redis rabbitmq
 ```
 
 Then you can run your API service locally with your preferred method (e.g., poetry run python3.11 api.py).
@@ -38,16 +38,25 @@ Then you can run your API service locally with your preferred method (e.g., poet
 When running services individually, make sure to:
 1. Check your `.env` file and adjust any necessary environment variables
 2. Ensure Redis connection settings match your local setup (default: `localhost:6379`)
-3. Update any service-specific environment variables if needed
+3. Ensure RabbitMQ connection settings match your local setup (default: `localhost:5672`)
+4. Update any service-specific environment variables if needed
 
 ### Important: Redis Host Configuration
 When running the API locally with Redis in Docker, you need to set the correct Redis host in your `.env` file:
 - For Docker-to-Docker communication (when running both services in Docker): use `REDIS_HOST=redis`
 - For local-to-Docker communication (when running API locally): use `REDIS_HOST=localhost`
 
+### Important: RabbitMQ Host Configuration
+When running the API locally with RabbitMQ in Docker, you need to set the correct RabbitMQ host in your `.env` file:
+- For Docker-to-Docker communication (when running both services in Docker): use `RABBITMQ_HOST=rabbitmq`
+- For local-to-Docker communication (when running API locally): use `RABBITMQ_HOST=localhost`
+
 Example `.env` configuration for local development:
 ```env
 REDIS_HOST=localhost (instead of 'redis')
 REDIS_PORT=6379
 REDIS_PASSWORD=
+
+RABBITMQ_HOST=localhost (instead of 'rabbitmq')
+RABBITMQ_PORT=5672
 ```
diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml
index ff409227..af115156 100644
--- a/backend/docker-compose.yml
+++ b/backend/docker-compose.yml
@@ -1,4 +1,4 @@
-version: '3.8'
+version: "3.8"
 
 services:
   api:
@@ -16,6 +16,8 @@ services:
     depends_on:
       redis:
         condition: service_healthy
+      rabbitmq:
+        condition: service_healthy
     networks:
       - app-network
     environment:
@@ -23,6 +25,8 @@ services:
       - REDIS_PORT=6379
       - REDIS_PASSWORD=
      - LOG_LEVEL=INFO
+      - RABBITMQ_HOST=rabbitmq
+      - RABBITMQ_PORT=5672
     logging:
       driver: "json-file"
       options:
@@ -31,10 +35,10 @@ services:
     deploy:
       resources:
         limits:
-          cpus: '14'
+          cpus: "14"
           memory: 48G
         reservations:
-          cpus: '8'
+          cpus: "8"
           memory: 32G
     healthcheck:
       test: ["CMD", "curl", "-f", "http://localhost:8000/api/health"]
@@ -43,6 +47,84 @@ services:
       retries: 3
       start_period: 40s
 
+  worker-1:
+    build:
+      context: .
+ dockerfile: Dockerfile + command: python -m dramatiq run_agent_background + env_file: + - .env + volumes: + - .:/app + - ./worker-1-logs:/app/logs + restart: unless-stopped + depends_on: + redis: + condition: service_healthy + rabbitmq: + condition: service_healthy + networks: + - app-network + environment: + - REDIS_HOST=redis + - REDIS_PORT=6379 + - REDIS_PASSWORD= + - LOG_LEVEL=INFO + - RABBITMQ_HOST=rabbitmq + - RABBITMQ_PORT=5672 + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" + deploy: + resources: + limits: + cpus: "14" + memory: 48G + reservations: + cpus: "8" + memory: 32G + + worker-2: + build: + context: . + dockerfile: Dockerfile + command: python -m dramatiq run_agent_background + env_file: + - .env + volumes: + - .:/app + - ./worker-2-logs:/app/logs + restart: unless-stopped + depends_on: + redis: + condition: service_healthy + rabbitmq: + condition: service_healthy + networks: + - app-network + environment: + - REDIS_HOST=redis + - REDIS_PORT=6379 + - REDIS_PASSWORD= + - LOG_LEVEL=INFO + - RABBITMQ_HOST=rabbitmq + - RABBITMQ_PORT=5672 + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" + deploy: + resources: + limits: + cpus: "14" + memory: 48G + reservations: + cpus: "8" + memory: 32G + redis: image: redis:7-alpine ports: @@ -67,16 +149,40 @@ services: deploy: resources: limits: - cpus: '2' + cpus: "2" memory: 12G reservations: - cpus: '1' + cpus: "1" memory: 8G rabbitmq: image: rabbitmq ports: - "127.0.0.1:5672:5672" + volumes: + - rabbitmq_data:/var/lib/rabbitmq + restart: unless-stopped + networks: + - app-network + healthcheck: + test: ["CMD", "rabbitmq-diagnostics", "-q", "ping"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" + deploy: + resources: + limits: + cpus: "2" + memory: 12G + reservations: + cpus: "1" + memory: 8G networks: app-network: @@ -84,3 +190,4 @@ networks: volumes: redis_data: + rabbitmq_data: diff --git a/backend/requirements.txt b/backend/requirements.txt index f8292068..33846135 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -31,4 +31,5 @@ vncdotool>=1.2.0 pydantic tavily-python>=0.5.4 pytesseract==0.3.13 -stripe>=7.0.0 \ No newline at end of file +stripe>=7.0.0 +dramatiq[rabbitmq]>=1.17.1 diff --git a/backend/run_agent_background.py b/backend/run_agent_background.py index 0c774b2b..1c330dd2 100644 --- a/backend/run_agent_background.py +++ b/backend/run_agent_background.py @@ -12,8 +12,11 @@ from agentpress.thread_manager import ThreadManager from services.supabase import DBConnection from services import redis from dramatiq.brokers.rabbitmq import RabbitmqBroker +import os -rabbitmq_broker = RabbitmqBroker(host="localhost", port="5672", middleware=[dramatiq.middleware.AsyncIO()]) +rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq') +rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672)) +rabbitmq_broker = RabbitmqBroker(host=rabbitmq_host, port=rabbitmq_port, middleware=[dramatiq.middleware.AsyncIO()]) dramatiq.set_broker(rabbitmq_broker) _initialized = False From 1e62257ab17028644c3caba3b5ab5fac69f22d59 Mon Sep 17 00:00:00 2001 From: sharath <29162020+tnfssc@users.noreply.github.com> Date: Thu, 15 May 2025 06:34:17 +0000 Subject: [PATCH 3/4] chore(dramatiq): cleanup code --- backend/agent/api.py | 264 +------------------------------------------ 1 file changed, 1 insertion(+), 263 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index 
feb42c01..12faf4a3 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -21,7 +21,7 @@ from services.billing import check_billing_status from utils.config import config from sandbox.sandbox import create_sandbox, get_or_start_sandbox from services.llm import make_llm_api_call -from run_agent_background import run_agent_background +from run_agent_background import run_agent_background, _cleanup_redis_response_list, update_agent_run_status # Initialize shared resources router = APIRouter() @@ -117,63 +117,6 @@ async def cleanup(): await redis.close() logger.info("Completed cleanup of agent API resources") -async def update_agent_run_status( - client, - agent_run_id: str, - status: str, - error: Optional[str] = None, - responses: Optional[List[Any]] = None # Expects parsed list of dicts -) -> bool: - """ - Centralized function to update agent run status. - Returns True if update was successful. - """ - try: - update_data = { - "status": status, - "completed_at": datetime.now(timezone.utc).isoformat() - } - - if error: - update_data["error"] = error - - if responses: - # Ensure responses are stored correctly as JSONB - update_data["responses"] = responses - - # Retry up to 3 times - for retry in range(3): - try: - update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute() - - if hasattr(update_result, 'data') and update_result.data: - logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})") - - # Verify the update - verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute() - if verify_result.data: - actual_status = verify_result.data[0].get('status') - completed_at = verify_result.data[0].get('completed_at') - logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}") - return True - else: - logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}") - if retry == 2: # Last retry - logger.error(f"Failed to update agent run status after all retries: {agent_run_id}") - return False - except Exception as db_error: - logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}") - if retry < 2: # Not the last retry yet - await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff - else: - logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True) - return False - except Exception as e: - logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True) - return False - - return False - async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None): """Update database and publish stop signal to Redis.""" logger.info(f"Stopping agent run: {agent_run_id}") @@ -234,16 +177,6 @@ async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None) logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}") - -async def _cleanup_redis_response_list(agent_run_id: str): - """Set TTL on the Redis response list.""" - response_list_key = f"agent_run:{agent_run_id}:responses" - try: - await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL) - logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}") - except Exception as e: - logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}") - # async def restore_running_agent_runs(): # """Mark 
agent runs that were still 'running' in the database as failed and clean up Redis resources.""" # logger.info("Restoring running agent runs after server restart") @@ -302,20 +235,6 @@ async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: st await verify_thread_access(client, thread_id, user_id) return agent_run_data -async def _cleanup_redis_instance_key(agent_run_id: str): - """Clean up the instance-specific Redis key for an agent run.""" - if not instance_id: - logger.warning("Instance ID not set, cannot clean up instance key.") - return - key = f"active_run:{instance_id}:{agent_run_id}" - logger.debug(f"Cleaning up Redis instance key: {key}") - try: - await redis.delete(key) - logger.debug(f"Successfully cleaned up Redis key: {key}") - except Exception as e: - logger.warning(f"Failed to clean up Redis key {key}: {str(e)}") - - async def get_or_create_project_sandbox(client, project_id: str): """Get or create a sandbox for a project.""" project = await client.table('projects').select('*').eq('project_id', project_id).execute() @@ -658,187 +577,6 @@ async def stream_agent_run( "Access-Control-Allow-Origin": "*" }) -# @dramatiq.actor -# async def run_agent_background( -# agent_run_id: str, -# thread_id: str, -# instance_id: str, # Use the global instance ID passed during initialization -# project_id: str, -# model_name: str, -# enable_thinking: Optional[bool], -# reasoning_effort: Optional[str], -# stream: bool, -# enable_context_manager: bool -# ): -# """Run the agent in the background using Redis for state.""" -# logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") -# logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})") - -# client = await db.client -# start_time = datetime.now(timezone.utc) -# total_responses = 0 -# pubsub = None -# stop_checker = None -# stop_signal_received = False - -# # Define Redis keys and channels -# response_list_key = f"agent_run:{agent_run_id}:responses" -# response_channel = f"agent_run:{agent_run_id}:new_response" -# instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" -# global_control_channel = f"agent_run:{agent_run_id}:control" -# instance_active_key = f"active_run:{instance_id}:{agent_run_id}" - -# async def check_for_stop_signal(): -# nonlocal stop_signal_received -# if not pubsub: return -# try: -# while not stop_signal_received: -# message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) -# if message and message.get("type") == "message": -# data = message.get("data") -# if isinstance(data, bytes): data = data.decode('utf-8') -# if data == "STOP": -# logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") -# stop_signal_received = True -# break -# # Periodically refresh the active run key TTL -# if total_responses % 50 == 0: # Refresh every 50 responses or so -# try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) -# except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") -# await asyncio.sleep(0.1) # Short sleep to prevent tight loop -# except asyncio.CancelledError: -# logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") -# except Exception as e: -# logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True) -# stop_signal_received = True # Stop the run if the checker fails - -# try: -# # Setup Pub/Sub 
listener for control signals -# pubsub = await redis.create_pubsub() -# await pubsub.subscribe(instance_control_channel, global_control_channel) -# logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") -# stop_checker = asyncio.create_task(check_for_stop_signal()) - -# # Ensure active run key exists and has TTL -# await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) - -# # Initialize agent generator -# agent_gen = run_agent( -# thread_id=thread_id, project_id=project_id, stream=stream, -# thread_manager=thread_manager, model_name=model_name, -# enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, -# enable_context_manager=enable_context_manager -# ) - -# final_status = "running" -# error_message = None - -# async for response in agent_gen: -# if stop_signal_received: -# logger.info(f"Agent run {agent_run_id} stopped by signal.") -# final_status = "stopped" -# break - -# # Store response in Redis list and publish notification -# response_json = json.dumps(response) -# await redis.rpush(response_list_key, response_json) -# await redis.publish(response_channel, "new") -# total_responses += 1 - -# # Check for agent-signaled completion or error -# if response.get('type') == 'status': -# status_val = response.get('status') -# if status_val in ['completed', 'failed', 'stopped']: -# logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") -# final_status = status_val -# if status_val == 'failed' or status_val == 'stopped': -# error_message = response.get('message', f"Run ended with status: {status_val}") -# break - -# # If loop finished without explicit completion/error/stop signal, mark as completed -# if final_status == "running": -# final_status = "completed" -# duration = (datetime.now(timezone.utc) - start_time).total_seconds() -# logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") -# completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} -# await redis.rpush(response_list_key, json.dumps(completion_message)) -# await redis.publish(response_channel, "new") # Notify about the completion message - -# # Fetch final responses from Redis for DB update -# all_responses_json = await redis.lrange(response_list_key, 0, -1) -# all_responses = [json.loads(r) for r in all_responses_json] - -# # Update DB status -# await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) - -# # Publish final control signal (END_STREAM or ERROR) -# control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" -# try: -# await redis.publish(global_control_channel, control_signal) -# # No need to publish to instance channel as the run is ending on this instance -# logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") -# except Exception as e: -# logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") - -# except Exception as e: -# error_message = str(e) -# traceback_str = traceback.format_exc() -# duration = (datetime.now(timezone.utc) - start_time).total_seconds() -# logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") -# final_status = "failed" - -# # Push error message to Redis list -# error_response = {"type": "status", "status": "error", 
"message": error_message} -# try: -# await redis.rpush(response_list_key, json.dumps(error_response)) -# await redis.publish(response_channel, "new") -# except Exception as redis_err: -# logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") - -# # Fetch final responses (including the error) -# all_responses = [] -# try: -# all_responses_json = await redis.lrange(response_list_key, 0, -1) -# all_responses = [json.loads(r) for r in all_responses_json] -# except Exception as fetch_err: -# logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") -# all_responses = [error_response] # Use the error message we tried to push - -# # Update DB status -# await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) - -# # Publish ERROR signal -# try: -# await redis.publish(global_control_channel, "ERROR") -# logger.debug(f"Published ERROR signal to {global_control_channel}") -# except Exception as e: -# logger.warning(f"Failed to publish ERROR signal: {str(e)}") - -# finally: -# # Cleanup stop checker task -# if stop_checker and not stop_checker.done(): -# stop_checker.cancel() -# try: await stop_checker -# except asyncio.CancelledError: pass -# except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") - -# # Close pubsub connection -# if pubsub: -# try: -# await pubsub.unsubscribe() -# await pubsub.close() -# logger.debug(f"Closed pubsub connection for {agent_run_id}") -# except Exception as e: -# logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") - -# # Set TTL on the response list in Redis -# await _cleanup_redis_response_list(agent_run_id) - -# # Remove the instance-specific active run key -# await _cleanup_redis_instance_key(agent_run_id) - -# logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") - async def generate_and_update_project_name(project_id: str, prompt: str): """Generates a project name using an LLM and updates the database.""" logger.info(f"Starting background task to generate name for project: {project_id}") From 35de7f1b40397497ddea86f962551aeafb804733 Mon Sep 17 00:00:00 2001 From: sharath <29162020+tnfssc@users.noreply.github.com> Date: Thu, 15 May 2025 23:28:28 +0000 Subject: [PATCH 4/4] fix(stream): redis latency issue mitigated --- backend/agent/api.py | 4 ++-- backend/agentpress/response_processor.py | 4 ++++ frontend/src/components/thread/types.ts | 1 + frontend/src/hooks/useAgentStream.ts | 29 +++++++++++++++++------- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index 618b2b0d..525c394d 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -744,8 +744,8 @@ async def run_agent_background( # Store response in Redis list and publish notification response_json = json.dumps(response) - await redis.rpush(response_list_key, response_json) - await redis.publish(response_channel, "new") + asyncio.create_task(redis.rpush(response_list_key, response_json)) + asyncio.create_task(redis.publish(response_channel, "new")) total_responses += 1 # Check for agent-signaled completion or error diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 78c95af5..ea6e028a 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -147,6 +147,8 @@ class 
ResponseProcessor: if assist_start_msg_obj: yield assist_start_msg_obj # --- End Start Events --- + __sequence = 0 + async for chunk in llm_response: if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason: finish_reason = chunk.choices[0].finish_reason @@ -175,12 +177,14 @@ class ResponseProcessor: # Yield ONLY content chunk (don't save) now_chunk = datetime.now(timezone.utc).isoformat() yield { + "sequence": __sequence, "message_id": None, "thread_id": thread_id, "type": "assistant", "is_llm_message": True, "content": json.dumps({"role": "assistant", "content": chunk_content}), "metadata": json.dumps({"stream_status": "chunk", "thread_run_id": thread_run_id}), "created_at": now_chunk, "updated_at": now_chunk } + __sequence += 1 else: logger.info("XML tool call limit reached - not yielding more content chunks") diff --git a/frontend/src/components/thread/types.ts b/frontend/src/components/thread/types.ts index 8aee81f9..eeb3053a 100644 --- a/frontend/src/components/thread/types.ts +++ b/frontend/src/components/thread/types.ts @@ -9,6 +9,7 @@ export type ThreadParams = { // Unified Message Interface matching the backend/database schema export interface UnifiedMessage { + sequence?: number; message_id: string | null; // Can be null for transient stream events (chunks, unsaved statuses) thread_id: string; type: 'user' | 'assistant' | 'tool' | 'system' | 'status' | 'browser_state'; // Add 'system' if used diff --git a/frontend/src/hooks/useAgentStream.ts b/frontend/src/hooks/useAgentStream.ts index 74406cb0..19c94645 100644 --- a/frontend/src/hooks/useAgentStream.ts +++ b/frontend/src/hooks/useAgentStream.ts @@ -1,4 +1,4 @@ -import { useState, useEffect, useRef, useCallback } from 'react'; +import { useState, useEffect, useRef, useCallback, useMemo } from 'react'; import { streamAgent, getAgentStatus, @@ -72,7 +72,9 @@ export function useAgentStream( ): UseAgentStreamResult { const [agentRunId, setAgentRunId] = useState(null); const [status, setStatus] = useState('idle'); - const [textContent, setTextContent] = useState(''); + const [textContent, setTextContent] = useState< + { content: string; sequence?: number }[] + >([]); const [toolCall, setToolCall] = useState(null); const [error, setError] = useState(null); @@ -82,6 +84,12 @@ export function useAgentStream( const threadIdRef = useRef(threadId); // Ref to hold the current threadId const setMessagesRef = useRef(setMessages); // Ref to hold the setMessages function + const orderedTextContent = useMemo(() => { + return textContent + .sort((a, b) => a.sequence - b.sequence) + .reduce((acc, curr) => acc + curr.content, ''); + }, [textContent]); + // Update refs if threadId or setMessages changes useEffect(() => { threadIdRef.current = threadId; @@ -148,7 +156,7 @@ export function useAgentStream( } // Reset streaming-specific state - setTextContent(''); + setTextContent([]); setToolCall(null); // Update status and clear run ID @@ -292,10 +300,15 @@ export function useAgentStream( parsedMetadata.stream_status === 'chunk' && parsedContent.content ) { - setTextContent((prev) => prev + parsedContent.content); + setTextContent((prev) => { + return prev.concat({ + sequence: message.sequence, + content: parsedContent.content, + }); + }); callbacks.onAssistantChunk?.({ content: parsedContent.content }); } else if (parsedMetadata.stream_status === 'complete') { - setTextContent(''); + setTextContent([]); setToolCall(null); if (message.message_id) callbacks.onMessage(message); } 
else if (!parsedMetadata.stream_status) { @@ -501,7 +514,7 @@ export function useAgentStream( } // Reset state on unmount if needed, though finalizeStream should handle most cases setStatus('idle'); - setTextContent(''); + setTextContent([]); setToolCall(null); setError(null); setAgentRunId(null); @@ -528,7 +541,7 @@ export function useAgentStream( } // Reset state before starting - setTextContent(''); + setTextContent([]); setToolCall(null); setError(null); updateStatus('connecting'); @@ -616,7 +629,7 @@ export function useAgentStream( return { status, - textContent, + textContent: orderedTextContent, toolCall, error, agentRunId,
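
Reviewer note: taken together, the four patches replace in-process `asyncio.create_task` execution with a RabbitMQ-backed Dramatiq actor, so agent runs survive API restarts and can be spread across the `worker-1`/`worker-2` services. Below is a minimal, self-contained sketch of that pattern. It only restates wiring already shown in the diffs (broker setup, the `AsyncIO` middleware, `actor.send(...)` as the producer call); the `__main__` producer stub and its argument values are hypothetical, added purely for illustration.

```python
import os

import dramatiq
from dramatiq.brokers.rabbitmq import RabbitmqBroker

# Broker wiring as in backend/run_agent_background.py after PATCH 2/4:
# host and port come from the environment, so the same module works inside
# docker compose (RABBITMQ_HOST=rabbitmq) and on a dev machine (localhost).
rabbitmq_broker = RabbitmqBroker(
    host=os.getenv("RABBITMQ_HOST", "rabbitmq"),
    port=int(os.getenv("RABBITMQ_PORT", 5672)),
    # The AsyncIO middleware is what allows the actor below to be `async def`.
    middleware=[dramatiq.middleware.AsyncIO()],
)
dramatiq.set_broker(rabbitmq_broker)


@dramatiq.actor
async def run_agent_background(agent_run_id: str, thread_id: str, **kwargs):
    # Body elided; see backend/run_agent_background.py in PATCH 1/4.
    ...


if __name__ == "__main__":
    # Producer side: what agent/api.py now does instead of asyncio.create_task.
    # .send() serializes the keyword arguments into a RabbitMQ message; any
    # worker started with `python -m dramatiq run_agent_background` can pick
    # it up. The IDs below are placeholders, not values from the PR.
    run_agent_background.send(
        agent_run_id="example-run-id",
        thread_id="example-thread-id",
    )
```

To try this against the branch: `docker compose up redis rabbitmq`, start one or more workers with `poetry run python3.11 -m dramatiq run_agent_background`, then start the API, which now enqueues runs rather than spawning them in-process.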