mirror of https://github.com/kortix-ai/suna.git
feat(dramatiq): added distributed worker support using dramatiq and rabbitmq
parent fe05ccfb4f
commit 9b3561213d
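Note: the diff below replaces in-process asyncio.create_task(...) execution with dramatiq actors backed by RabbitMQ. As a minimal, illustrative sketch of that pattern (not this repo's exact module layout; the broker options here are assumptions that mirror the new run_agent_background module further down):

    import dramatiq
    from dramatiq.brokers.rabbitmq import RabbitmqBroker

    # Register a RabbitMQ-backed broker; the AsyncIO middleware allows async actors.
    broker = RabbitmqBroker(host="localhost", port=5672, middleware=[dramatiq.middleware.AsyncIO()])
    dramatiq.set_broker(broker)

    @dramatiq.actor
    async def do_work(item_id: str):
        # Runs inside a dramatiq worker process, not the producer's process.
        print(f"processing {item_id}")

    # A producer (e.g. the API) enqueues a message instead of awaiting the coroutine:
    do_work.send("123")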
@@ -21,6 +21,7 @@ from services.billing import check_billing_status
 from utils.config import config
 from sandbox.sandbox import create_sandbox, get_or_start_sandbox
 from services.llm import make_llm_api_call
+from run_agent_background import run_agent_background

 # Initialize shared resources
 router = APIRouter()
@@ -428,19 +429,15 @@ async def start_agent(
         logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")

     # Run the agent in the background
-    task = asyncio.create_task(
-        run_agent_background(
-            agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
-            project_id=project_id, sandbox=sandbox,
-            model_name=model_name, # Already resolved above
-            enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort,
-            stream=body.stream, enable_context_manager=body.enable_context_manager
-        )
-    )
-
-    # Set a callback to clean up Redis instance key when task is done
-    task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id)))
+    run_agent_background.send(
+        agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
+        project_id=project_id,
+        model_name=model_name, # Already resolved above
+        enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort,
+        stream=body.stream, enable_context_manager=body.enable_context_manager
+    )

     return {"agent_run_id": agent_run_id, "status": "running"}

 @router.post("/agent-run/{agent_run_id}/stop")
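Note: run_agent_background.send(...) does not execute the function in this process; dramatiq serializes the keyword arguments into a message and enqueues it on RabbitMQ, where a worker process picks it up and runs the actor. That is presumably also why the non-serializable sandbox object is no longer passed here. Conceptually (field names illustrative, not dramatiq's exact wire format):

    # roughly what .send(...) puts on the queue
    {"actor_name": "run_agent_background",
     "kwargs": {"agent_run_id": "...", "thread_id": "...", "instance_id": "...", "model_name": "..."}}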
@@ -661,186 +658,186 @@ async def stream_agent_run(
         "Access-Control-Allow-Origin": "*"
     })

-async def run_agent_background(
-    agent_run_id: str,
-    thread_id: str,
-    instance_id: str, # Use the global instance ID passed during initialization
-    project_id: str,
-    sandbox,
-    model_name: str,
-    enable_thinking: Optional[bool],
-    reasoning_effort: Optional[str],
-    stream: bool,
-    enable_context_manager: bool
-):
-    """Run the agent in the background using Redis for state."""
-    logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
-    logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
+# @dramatiq.actor
+# async def run_agent_background(
+#     agent_run_id: str,
+#     thread_id: str,
+#     instance_id: str, # Use the global instance ID passed during initialization
+#     project_id: str,
+#     model_name: str,
+#     enable_thinking: Optional[bool],
+#     reasoning_effort: Optional[str],
+#     stream: bool,
+#     enable_context_manager: bool
+# ):
+#     """Run the agent in the background using Redis for state."""
+#     logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
+#     logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")

-    client = await db.client
-    start_time = datetime.now(timezone.utc)
-    total_responses = 0
-    pubsub = None
-    stop_checker = None
-    stop_signal_received = False
+#     client = await db.client
+#     start_time = datetime.now(timezone.utc)
+#     total_responses = 0
+#     pubsub = None
+#     stop_checker = None
+#     stop_signal_received = False

-    # Define Redis keys and channels
-    response_list_key = f"agent_run:{agent_run_id}:responses"
-    response_channel = f"agent_run:{agent_run_id}:new_response"
-    instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
-    global_control_channel = f"agent_run:{agent_run_id}:control"
-    instance_active_key = f"active_run:{instance_id}:{agent_run_id}"
+#     # Define Redis keys and channels
+#     response_list_key = f"agent_run:{agent_run_id}:responses"
+#     response_channel = f"agent_run:{agent_run_id}:new_response"
+#     instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
+#     global_control_channel = f"agent_run:{agent_run_id}:control"
+#     instance_active_key = f"active_run:{instance_id}:{agent_run_id}"

-    async def check_for_stop_signal():
-        nonlocal stop_signal_received
-        if not pubsub: return
-        try:
-            while not stop_signal_received:
-                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
-                if message and message.get("type") == "message":
-                    data = message.get("data")
-                    if isinstance(data, bytes): data = data.decode('utf-8')
-                    if data == "STOP":
-                        logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
-                        stop_signal_received = True
-                        break
-                # Periodically refresh the active run key TTL
-                if total_responses % 50 == 0: # Refresh every 50 responses or so
-                    try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL)
-                    except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}")
-                await asyncio.sleep(0.1) # Short sleep to prevent tight loop
-        except asyncio.CancelledError:
-            logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
-        except Exception as e:
-            logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
-            stop_signal_received = True # Stop the run if the checker fails
+#     async def check_for_stop_signal():
+#         nonlocal stop_signal_received
+#         if not pubsub: return
+#         try:
+#             while not stop_signal_received:
+#                 message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
+#                 if message and message.get("type") == "message":
+#                     data = message.get("data")
+#                     if isinstance(data, bytes): data = data.decode('utf-8')
+#                     if data == "STOP":
+#                         logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
+#                         stop_signal_received = True
+#                         break
+#                 # Periodically refresh the active run key TTL
+#                 if total_responses % 50 == 0: # Refresh every 50 responses or so
+#                     try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL)
+#                     except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}")
+#                 await asyncio.sleep(0.1) # Short sleep to prevent tight loop
+#         except asyncio.CancelledError:
+#             logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
+#         except Exception as e:
+#             logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
+#             stop_signal_received = True # Stop the run if the checker fails

-    try:
-        # Setup Pub/Sub listener for control signals
-        pubsub = await redis.create_pubsub()
-        await pubsub.subscribe(instance_control_channel, global_control_channel)
-        logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
-        stop_checker = asyncio.create_task(check_for_stop_signal())
+#     try:
+#         # Setup Pub/Sub listener for control signals
+#         pubsub = await redis.create_pubsub()
+#         await pubsub.subscribe(instance_control_channel, global_control_channel)
+#         logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
+#         stop_checker = asyncio.create_task(check_for_stop_signal())

-        # Ensure active run key exists and has TTL
-        await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)
+#         # Ensure active run key exists and has TTL
+#         await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)

-        # Initialize agent generator
-        agent_gen = run_agent(
-            thread_id=thread_id, project_id=project_id, stream=stream,
-            thread_manager=thread_manager, model_name=model_name,
-            enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
-            enable_context_manager=enable_context_manager
-        )
+#         # Initialize agent generator
+#         agent_gen = run_agent(
+#             thread_id=thread_id, project_id=project_id, stream=stream,
+#             thread_manager=thread_manager, model_name=model_name,
+#             enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
+#             enable_context_manager=enable_context_manager
+#         )

-        final_status = "running"
-        error_message = None
+#         final_status = "running"
+#         error_message = None

-        async for response in agent_gen:
-            if stop_signal_received:
-                logger.info(f"Agent run {agent_run_id} stopped by signal.")
-                final_status = "stopped"
-                break
+#         async for response in agent_gen:
+#             if stop_signal_received:
+#                 logger.info(f"Agent run {agent_run_id} stopped by signal.")
+#                 final_status = "stopped"
+#                 break

-            # Store response in Redis list and publish notification
-            response_json = json.dumps(response)
-            await redis.rpush(response_list_key, response_json)
-            await redis.publish(response_channel, "new")
-            total_responses += 1
+#             # Store response in Redis list and publish notification
+#             response_json = json.dumps(response)
+#             await redis.rpush(response_list_key, response_json)
+#             await redis.publish(response_channel, "new")
+#             total_responses += 1

-            # Check for agent-signaled completion or error
-            if response.get('type') == 'status':
-                status_val = response.get('status')
-                if status_val in ['completed', 'failed', 'stopped']:
-                    logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}")
-                    final_status = status_val
-                    if status_val == 'failed' or status_val == 'stopped':
-                        error_message = response.get('message', f"Run ended with status: {status_val}")
-                    break
+#             # Check for agent-signaled completion or error
+#             if response.get('type') == 'status':
+#                 status_val = response.get('status')
+#                 if status_val in ['completed', 'failed', 'stopped']:
+#                     logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}")
+#                     final_status = status_val
+#                     if status_val == 'failed' or status_val == 'stopped':
+#                         error_message = response.get('message', f"Run ended with status: {status_val}")
+#                     break

-        # If loop finished without explicit completion/error/stop signal, mark as completed
-        if final_status == "running":
-            final_status = "completed"
-            duration = (datetime.now(timezone.utc) - start_time).total_seconds()
-            logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
-            completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
-            await redis.rpush(response_list_key, json.dumps(completion_message))
-            await redis.publish(response_channel, "new") # Notify about the completion message
+#         # If loop finished without explicit completion/error/stop signal, mark as completed
+#         if final_status == "running":
+#             final_status = "completed"
+#             duration = (datetime.now(timezone.utc) - start_time).total_seconds()
+#             logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
+#             completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
+#             await redis.rpush(response_list_key, json.dumps(completion_message))
+#             await redis.publish(response_channel, "new") # Notify about the completion message

-        # Fetch final responses from Redis for DB update
-        all_responses_json = await redis.lrange(response_list_key, 0, -1)
-        all_responses = [json.loads(r) for r in all_responses_json]
+#         # Fetch final responses from Redis for DB update
+#         all_responses_json = await redis.lrange(response_list_key, 0, -1)
+#         all_responses = [json.loads(r) for r in all_responses_json]

-        # Update DB status
-        await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses)
+#         # Update DB status
+#         await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses)

-        # Publish final control signal (END_STREAM or ERROR)
-        control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP"
-        try:
-            await redis.publish(global_control_channel, control_signal)
-            # No need to publish to instance channel as the run is ending on this instance
-            logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}")
-        except Exception as e:
-            logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}")
+#         # Publish final control signal (END_STREAM or ERROR)
+#         control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP"
+#         try:
+#             await redis.publish(global_control_channel, control_signal)
+#             # No need to publish to instance channel as the run is ending on this instance
+#             logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}")
+#         except Exception as e:
+#             logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}")

-    except Exception as e:
-        error_message = str(e)
-        traceback_str = traceback.format_exc()
-        duration = (datetime.now(timezone.utc) - start_time).total_seconds()
-        logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
-        final_status = "failed"
+#     except Exception as e:
+#         error_message = str(e)
+#         traceback_str = traceback.format_exc()
+#         duration = (datetime.now(timezone.utc) - start_time).total_seconds()
+#         logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
+#         final_status = "failed"

-        # Push error message to Redis list
-        error_response = {"type": "status", "status": "error", "message": error_message}
-        try:
-            await redis.rpush(response_list_key, json.dumps(error_response))
-            await redis.publish(response_channel, "new")
-        except Exception as redis_err:
-            logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}")
+#         # Push error message to Redis list
+#         error_response = {"type": "status", "status": "error", "message": error_message}
+#         try:
+#             await redis.rpush(response_list_key, json.dumps(error_response))
+#             await redis.publish(response_channel, "new")
+#         except Exception as redis_err:
+#             logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}")

-        # Fetch final responses (including the error)
-        all_responses = []
-        try:
-            all_responses_json = await redis.lrange(response_list_key, 0, -1)
-            all_responses = [json.loads(r) for r in all_responses_json]
-        except Exception as fetch_err:
-            logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}")
-            all_responses = [error_response] # Use the error message we tried to push
+#         # Fetch final responses (including the error)
+#         all_responses = []
+#         try:
+#             all_responses_json = await redis.lrange(response_list_key, 0, -1)
+#             all_responses = [json.loads(r) for r in all_responses_json]
+#         except Exception as fetch_err:
+#             logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}")
+#             all_responses = [error_response] # Use the error message we tried to push

-        # Update DB status
-        await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses)
+#         # Update DB status
+#         await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses)

-        # Publish ERROR signal
-        try:
-            await redis.publish(global_control_channel, "ERROR")
-            logger.debug(f"Published ERROR signal to {global_control_channel}")
-        except Exception as e:
-            logger.warning(f"Failed to publish ERROR signal: {str(e)}")
+#         # Publish ERROR signal
+#         try:
+#             await redis.publish(global_control_channel, "ERROR")
+#             logger.debug(f"Published ERROR signal to {global_control_channel}")
+#         except Exception as e:
+#             logger.warning(f"Failed to publish ERROR signal: {str(e)}")

-    finally:
-        # Cleanup stop checker task
-        if stop_checker and not stop_checker.done():
-            stop_checker.cancel()
-            try: await stop_checker
-            except asyncio.CancelledError: pass
-            except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}")
+#     finally:
+#         # Cleanup stop checker task
+#         if stop_checker and not stop_checker.done():
+#             stop_checker.cancel()
+#             try: await stop_checker
+#             except asyncio.CancelledError: pass
+#             except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}")

-        # Close pubsub connection
-        if pubsub:
-            try:
-                await pubsub.unsubscribe()
-                await pubsub.close()
-                logger.debug(f"Closed pubsub connection for {agent_run_id}")
-            except Exception as e:
-                logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}")
+#         # Close pubsub connection
+#         if pubsub:
+#             try:
+#                 await pubsub.unsubscribe()
+#                 await pubsub.close()
+#                 logger.debug(f"Closed pubsub connection for {agent_run_id}")
+#             except Exception as e:
+#                 logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}")

-        # Set TTL on the response list in Redis
-        await _cleanup_redis_response_list(agent_run_id)
+#         # Set TTL on the response list in Redis
+#         await _cleanup_redis_response_list(agent_run_id)

-        # Remove the instance-specific active run key
-        await _cleanup_redis_instance_key(agent_run_id)
+#         # Remove the instance-specific active run key
+#         await _cleanup_redis_instance_key(agent_run_id)

-        logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}")
+#         logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}")

 async def generate_and_update_project_name(project_id: str, prompt: str):
     """Generates a project name using an LLM and updates the database."""
@@ -1030,16 +1027,13 @@ async def initiate_agent_with_files(
         logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")

     # Run agent in background
-    task = asyncio.create_task(
-        run_agent_background(
-            agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
-            project_id=project_id, sandbox=sandbox,
-            model_name=model_name, # Already resolved above
-            enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
-            stream=stream, enable_context_manager=enable_context_manager
-        )
-    )
-    task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id)))
+    run_agent_background.send(
+        agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
+        project_id=project_id,
+        model_name=model_name, # Already resolved above
+        enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
+        stream=stream, enable_context_manager=enable_context_manager
+    )

     return {"thread_id": thread_id, "agent_run_id": agent_run_id}
@@ -73,9 +73,14 @@
           cpus: '1'
           memory: 8G

+  rabbitmq:
+    image: rabbitmq
+    ports:
+      - "127.0.0.1:5672:5672"
+
 networks:
   app-network:
     driver: bridge

 volumes:
   redis_data:
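Note: with RabbitMQ published on 127.0.0.1:5672, the actors in the new module below are meant to be consumed by a separate dramatiq worker process. The commit itself does not show how that worker is launched; assuming the module is importable as run_agent_background, an invocation would look something like `dramatiq run_agent_background --processes 2 --threads 4` (hypothetical values).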
@@ -0,0 +1,306 @@
import asyncio
import json
import traceback
from datetime import datetime, timezone
from typing import Optional
from services import redis
from agent.run import run_agent
from utils.logger import logger
import dramatiq
import uuid
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from services import redis
from dramatiq.brokers.rabbitmq import RabbitmqBroker

rabbitmq_broker = RabbitmqBroker(host="localhost", port="5672", middleware=[dramatiq.middleware.AsyncIO()])
dramatiq.set_broker(rabbitmq_broker)

_initialized = False
db = DBConnection()
thread_manager = None
instance_id = "single"

async def initialize():
    """Initialize the agent API with resources from the main API."""
    global thread_manager, db, instance_id, _initialized
    if _initialized:
        return

    # Use provided instance_id or generate a new one
    if not instance_id:
        # Generate instance ID
        instance_id = str(uuid.uuid4())[:8]
    await redis.initialize_async()
    await db.initialize()
    thread_manager = ThreadManager()

    _initialized = True
    logger.info(f"Initialized agent API with instance ID: {instance_id}")


@dramatiq.actor
async def run_agent_background(
    agent_run_id: str,
    thread_id: str,
    instance_id: str, # Use the global instance ID passed during initialization
    project_id: str,
    model_name: str,
    enable_thinking: Optional[bool],
    reasoning_effort: Optional[str],
    stream: bool,
    enable_context_manager: bool
):
    """Run the agent in the background using Redis for state."""
    await initialize()

    logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
    logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")

    client = await db.client
    start_time = datetime.now(timezone.utc)
    total_responses = 0
    pubsub = None
    stop_checker = None
    stop_signal_received = False

    # Define Redis keys and channels
    response_list_key = f"agent_run:{agent_run_id}:responses"
    response_channel = f"agent_run:{agent_run_id}:new_response"
    instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
    global_control_channel = f"agent_run:{agent_run_id}:control"
    instance_active_key = f"active_run:{instance_id}:{agent_run_id}"

    async def check_for_stop_signal():
        nonlocal stop_signal_received
        if not pubsub: return
        try:
            while not stop_signal_received:
                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
                if message and message.get("type") == "message":
                    data = message.get("data")
                    if isinstance(data, bytes): data = data.decode('utf-8')
                    if data == "STOP":
                        logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
                        stop_signal_received = True
                        break
                # Periodically refresh the active run key TTL
                if total_responses % 50 == 0: # Refresh every 50 responses or so
                    try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL)
                    except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}")
                await asyncio.sleep(0.1) # Short sleep to prevent tight loop
        except asyncio.CancelledError:
            logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
        except Exception as e:
            logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
            stop_signal_received = True # Stop the run if the checker fails

    try:
        # Setup Pub/Sub listener for control signals
        pubsub = await redis.create_pubsub()
        await pubsub.subscribe(instance_control_channel, global_control_channel)
        logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
        stop_checker = asyncio.create_task(check_for_stop_signal())

        # Ensure active run key exists and has TTL
        await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)

        # Initialize agent generator
        agent_gen = run_agent(
            thread_id=thread_id, project_id=project_id, stream=stream,
            thread_manager=thread_manager, model_name=model_name,
            enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
            enable_context_manager=enable_context_manager
        )

        final_status = "running"
        error_message = None

        async for response in agent_gen:
            if stop_signal_received:
                logger.info(f"Agent run {agent_run_id} stopped by signal.")
                final_status = "stopped"
                break

            # Store response in Redis list and publish notification
            response_json = json.dumps(response)
            await redis.rpush(response_list_key, response_json)
            await redis.publish(response_channel, "new")
            total_responses += 1

            # Check for agent-signaled completion or error
            if response.get('type') == 'status':
                status_val = response.get('status')
                if status_val in ['completed', 'failed', 'stopped']:
                    logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}")
                    final_status = status_val
                    if status_val == 'failed' or status_val == 'stopped':
                        error_message = response.get('message', f"Run ended with status: {status_val}")
                    break

        # If loop finished without explicit completion/error/stop signal, mark as completed
        if final_status == "running":
            final_status = "completed"
            duration = (datetime.now(timezone.utc) - start_time).total_seconds()
            logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
            completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
            await redis.rpush(response_list_key, json.dumps(completion_message))
            await redis.publish(response_channel, "new") # Notify about the completion message

        # Fetch final responses from Redis for DB update
        all_responses_json = await redis.lrange(response_list_key, 0, -1)
        all_responses = [json.loads(r) for r in all_responses_json]

        # Update DB status
        await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses)

        # Publish final control signal (END_STREAM or ERROR)
        control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP"
        try:
            await redis.publish(global_control_channel, control_signal)
            # No need to publish to instance channel as the run is ending on this instance
            logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}")
        except Exception as e:
            logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}")

    except Exception as e:
        error_message = str(e)
        traceback_str = traceback.format_exc()
        duration = (datetime.now(timezone.utc) - start_time).total_seconds()
        logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
        final_status = "failed"

        # Push error message to Redis list
        error_response = {"type": "status", "status": "error", "message": error_message}
        try:
            await redis.rpush(response_list_key, json.dumps(error_response))
            await redis.publish(response_channel, "new")
        except Exception as redis_err:
            logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}")

        # Fetch final responses (including the error)
        all_responses = []
        try:
            all_responses_json = await redis.lrange(response_list_key, 0, -1)
            all_responses = [json.loads(r) for r in all_responses_json]
        except Exception as fetch_err:
            logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}")
            all_responses = [error_response] # Use the error message we tried to push

        # Update DB status
        await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses)

        # Publish ERROR signal
        try:
            await redis.publish(global_control_channel, "ERROR")
            logger.debug(f"Published ERROR signal to {global_control_channel}")
        except Exception as e:
            logger.warning(f"Failed to publish ERROR signal: {str(e)}")

    finally:
        # Cleanup stop checker task
        if stop_checker and not stop_checker.done():
            stop_checker.cancel()
            try: await stop_checker
            except asyncio.CancelledError: pass
            except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}")

        # Close pubsub connection
        if pubsub:
            try:
                await pubsub.unsubscribe()
                await pubsub.close()
                logger.debug(f"Closed pubsub connection for {agent_run_id}")
            except Exception as e:
                logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}")

        # Set TTL on the response list in Redis
        await _cleanup_redis_response_list(agent_run_id)

        # Remove the instance-specific active run key
        await _cleanup_redis_instance_key(agent_run_id)

        logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}")


async def _cleanup_redis_instance_key(agent_run_id: str):
    """Clean up the instance-specific Redis key for an agent run."""
    if not instance_id:
        logger.warning("Instance ID not set, cannot clean up instance key.")
        return
    key = f"active_run:{instance_id}:{agent_run_id}"
    logger.debug(f"Cleaning up Redis instance key: {key}")
    try:
        await redis.delete(key)
        logger.debug(f"Successfully cleaned up Redis key: {key}")
    except Exception as e:
        logger.warning(f"Failed to clean up Redis key {key}: {str(e)}")

# TTL for Redis response lists (24 hours)
REDIS_RESPONSE_LIST_TTL = 3600 * 24

async def _cleanup_redis_response_list(agent_run_id: str):
    """Set TTL on the Redis response list."""
    response_list_key = f"agent_run:{agent_run_id}:responses"
    try:
        await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL)
        logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}")
    except Exception as e:
        logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}")

async def update_agent_run_status(
    client,
    agent_run_id: str,
    status: str,
    error: Optional[str] = None,
    responses: Optional[list[any]] = None # Expects parsed list of dicts
) -> bool:
    """
    Centralized function to update agent run status.
    Returns True if update was successful.
    """
    try:
        update_data = {
            "status": status,
            "completed_at": datetime.now(timezone.utc).isoformat()
        }

        if error:
            update_data["error"] = error

        if responses:
            # Ensure responses are stored correctly as JSONB
            update_data["responses"] = responses

        # Retry up to 3 times
        for retry in range(3):
            try:
                update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute()

                if hasattr(update_result, 'data') and update_result.data:
                    logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})")

                    # Verify the update
                    verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute()
                    if verify_result.data:
                        actual_status = verify_result.data[0].get('status')
                        completed_at = verify_result.data[0].get('completed_at')
                        logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}")
                    return True
                else:
                    logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}")
                    if retry == 2: # Last retry
                        logger.error(f"Failed to update agent run status after all retries: {agent_run_id}")
                        return False
            except Exception as db_error:
                logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}")
                if retry < 2: # Not the last retry yet
                    await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff
                else:
                    logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True)
                    return False
    except Exception as e:
        logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True)
        return False

    return False