refactor(agent): remove worker_health script and migrate run_agent_background functionality to run_agent module

sharath 2025-07-07 20:27:45 +00:00
parent 619fe78969
commit 3bb7219bef
No known key found for this signature in database
4 changed files with 78 additions and 126 deletions
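In short: the run_agent_background module is gone, its streaming entry point now lives in agent.run_agent as run_agent_run_stream, and the StreamBroadcaster helper moves with it. A minimal, hedged sketch of what a caller of the renamed async generator could look like; parameter names are taken from the call sites in the diff below, and **settings stands in for every argument not visible in this diff:

import typing

from agent.run_agent import run_agent_run_stream

async def consume(agent_run_id: str, thread_id: str, project_id: str, **settings: typing.Any) -> None:
    # **settings covers the remaining keyword arguments (model_name, enable_thinking,
    # agent_config, ...) that the stream_agent_run call site below supplies.
    async for sse_chunk in run_agent_run_stream(
        agent_run_id=agent_run_id,
        thread_id=thread_id,
        instance_id="single",   # module-level default seen in the second file below
        project_id=project_id,
        **settings,
    ):
        print(sse_chunk, end="")  # each chunk is already SSE-formatted: "data: ...\n\n"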

View File

@@ -5,14 +5,11 @@ import json
 import traceback
 from datetime import datetime, timezone
 import uuid
-from typing import Optional, List, Dict, Any, AsyncIterable
-import jwt
+from typing import Optional, List, Dict, Any
 from pydantic import BaseModel
-import tempfile
 import os
 from resumable_stream.runtime import create_resumable_stream_context, ResumableStreamContext
-from agentpress.thread_manager import ThreadManager
 from services.supabase import DBConnection
 from services import redis
 from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
@@ -21,7 +18,7 @@ from services.billing import check_billing_status, can_use_model
 from utils.config import config
 from sandbox.sandbox import create_sandbox, delete_sandbox, get_or_start_sandbox
 from services.llm import make_llm_api_call
-from run_agent_background import run_agent_background, update_agent_run_status
+from agent.run_agent import run_agent_run_stream, update_agent_run_status, StreamBroadcaster
 from utils.constants import MODEL_NAME_ALIASES
 from flags.flags import is_enabled
@@ -35,43 +32,6 @@ REDIS_RESPONSE_LIST_TTL = 3600 * 24
 stream_context_global: Optional[ResumableStreamContext] = None
-
-# Create stream broadcaster for multiple consumers
-class StreamBroadcaster:
-    def __init__(self, source: AsyncIterable[Any]):
-        self.source = source
-        self.queues: List[asyncio.Queue] = []
-
-    def add_consumer(self) -> asyncio.Queue:
-        q: asyncio.Queue = asyncio.Queue()
-        self.queues.append(q)
-        return q
-
-    async def start(self) -> None:
-        async for chunk in self.source:
-            for q in self.queues:
-                await q.put(chunk)
-        for q in self.queues:
-            await q.put(None)  # Sentinel to close consumers
-
-    # Consumer wrapper as an async generator
-    @staticmethod
-    async def queue_to_stream(queue: asyncio.Queue) -> AsyncIterable[Any]:
-        while True:
-            chunk = await queue.get()
-            if chunk is None:
-                break
-            yield chunk
-
-    # Print consumer task
-    @staticmethod
-    async def iterate_bg(queue: asyncio.Queue) -> None:
-        while True:
-            chunk = await queue.get()
-            if chunk is None:
-                break
-            pass
-
 async def get_stream_context():
     global stream_context_global
     if stream_context_global:
@@ -438,8 +398,6 @@ async def start_agent(
     except Exception as e:
         logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
 
-    request_id = structlog.contextvars.get_contextvars().get('request_id')
-
     return {"agent_run_id": agent_run_id, "status": "running"}
 
 @router.post("/agent-run/{agent_run_id}/stop")
@@ -718,7 +676,7 @@ async def stream_agent_run(
         return
 
     # Create the stream
-    stream = await stream_context.resumable_stream(agent_run_id, lambda: run_agent_background(
+    stream = await stream_context.resumable_stream(agent_run_id, lambda: run_agent_run_stream(
         agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
         project_id=project_id,
         model_name=model_name,
@@ -1037,8 +995,6 @@ async def initiate_agent_with_files(
         except Exception as e:
             logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
 
-        request_id = structlog.contextvars.get_contextvars().get('request_id')
-
         return {"thread_id": thread_id, "agent_run_id": agent_run_id}
 
     except Exception as e:

View File

@@ -6,7 +6,7 @@ import sentry
 import asyncio
 import traceback
 from datetime import datetime, timezone
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, AsyncIterable
 from services import redis
 from agent.run import run_agent
 from utils.logger import logger, structlog
@@ -22,6 +22,43 @@ db = DBConnection()
 instance_id = "single"
 
+# Create stream broadcaster for multiple consumers
+class StreamBroadcaster:
+    def __init__(self, source: AsyncIterable[Any]):
+        self.source = source
+        self.queues: List[asyncio.Queue] = []
+
+    def add_consumer(self) -> asyncio.Queue:
+        q: asyncio.Queue = asyncio.Queue()
+        self.queues.append(q)
+        return q
+
+    async def start(self) -> None:
+        async for chunk in self.source:
+            for q in self.queues:
+                await q.put(chunk)
+        for q in self.queues:
+            await q.put(None)  # Sentinel to close consumers
+
+    # Consumer wrapper as an async generator
+    @staticmethod
+    async def queue_to_stream(queue: asyncio.Queue) -> AsyncIterable[Any]:
+        while True:
+            chunk = await queue.get()
+            if chunk is None:
+                break
+            yield chunk
+
+    # Print consumer task
+    @staticmethod
+    async def iterate_bg(queue: asyncio.Queue) -> None:
+        while True:
+            chunk = await queue.get()
+            if chunk is None:
+                break
+            pass
+
 async def initialize():
     """Initialize the agent API with resources from the main API."""
     global db, instance_id, _initialized
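For reference (not part of this commit), a minimal usage sketch of how the relocated StreamBroadcaster fans a single agent stream out to several consumers. The class is defined in agent.run_agent per this diff; fake_agent_stream is a purely illustrative stand-in for run_agent_run_stream(...):

import asyncio

from agent.run_agent import StreamBroadcaster  # defined in agent.run_agent per this diff

async def fake_agent_stream():
    # Stand-in for run_agent_run_stream(...); yields SSE-style chunks.
    for i in range(3):
        yield f'data: {{"chunk": {i}}}\n\n'

async def main():
    broadcaster = StreamBroadcaster(fake_agent_stream())
    client_queue = broadcaster.add_consumer()      # e.g. streamed back to the HTTP client
    background_queue = broadcaster.add_consumer()  # e.g. drained for persistence/logging

    pump = asyncio.create_task(broadcaster.start())
    drain = asyncio.create_task(StreamBroadcaster.iterate_bg(background_queue))

    async for chunk in StreamBroadcaster.queue_to_stream(client_queue):
        print(chunk, end="")

    await pump
    await drain

asyncio.run(main())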
@@ -41,7 +78,7 @@ async def check_health(key: str):
     await redis.set(key, "healthy", ex=redis.REDIS_KEY_TTL)
 
-async def run_agent_background(
+async def run_agent_run_stream(
     agent_run_id: str,
     thread_id: str,
     instance_id: str,
@@ -107,22 +144,28 @@ async def run_agent_background(
     stop_signal_received = False
     stop_redis_key = f"stop_signal:{agent_run_id}"
 
     async def check_for_stop_signal():
         nonlocal stop_signal_received
         try:
             while not stop_signal_received:
                 message = await redis.client.get(stop_redis_key)
                 if message == "STOP":
-                    logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
+                    logger.info(
+                        f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})"
+                    )
                     stop_signal_received = True
                     break
                 await asyncio.sleep(0.5)  # Short sleep to prevent tight loop
         except asyncio.CancelledError:
-            logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
+            logger.info(
+                f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})"
+            )
         except Exception as e:
-            logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
-            stop_signal_received = True  # Stop the run if the checker fails
+            logger.error(
+                f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True
+            )
+            stop_signal_received = True  # Stop the run if the checker fails
 
     asyncio.create_task(check_for_stop_signal())
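The other half of this stop-signal handshake is not shown in this hunk. A hedged sketch of what a stop request presumably looks like, reusing the same key format and the services.redis wrapper seen elsewhere in this diff; the actual /stop endpoint implementation may differ:

from services import redis

async def request_stop(agent_run_id: str) -> None:
    # Hypothetical helper: check_for_stop_signal() above polls this key for "STOP".
    stop_redis_key = f"stop_signal:{agent_run_id}"
    await redis.set(stop_redis_key, "STOP", ex=redis.REDIS_KEY_TTL)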
@@ -150,7 +193,9 @@ async def run_agent_background(
             if stop_signal_received:
                 logger.info(f"Agent run {agent_run_id} stopped by signal.")
                 final_status = "stopped"
-                trace.span(name="agent_run_stopped").end(status_message="agent_run_stopped", level="WARNING")
+                trace.span(name="agent_run_stopped").end(
+                    status_message="agent_run_stopped", level="WARNING"
+                )
                 break
 
             all_responses.append(response)  # Keep for DB updates
@@ -158,7 +203,7 @@ async def run_agent_background(
                 yield f"data: {json.dumps(response)}\n\n"
             else:
                 yield f"data: {response}\n\n"
 
             # Check for agent-signaled completion or error
             if response.get("type") == "status":
                 status_val = response.get("status")

View File

@@ -249,46 +249,32 @@ Please respond appropriately to this trigger event."""
             "agent_id": agent_config['agent_id'],
             "agent_version_id": agent_config.get('current_version_id'),
             "status": "running",
-            "started_at": datetime.now(timezone.utc).isoformat()
+            "started_at": datetime.now(timezone.utc).isoformat(),
+            "metadata": {
+                "model_name": "anthropic/claude-sonnet-4-20250514",
+                "enable_thinking": False,
+                "reasoning_effort": "low",
+                "enable_context_manager": True,
+                "trigger_execution": True,
+                "trigger_variables": trigger_variables
+            }
         }
 
         agent_run = await client.table('agent_runs').insert(agent_run_data).execute()
         agent_run_id = agent_run.data[0]['id']
 
-        # Import and use the existing agent background execution
+        # Register this run in Redis with TTL using trigger executor instance ID
+        instance_id = "trigger_executor"
+        instance_key = f"active_run:{instance_id}:{agent_run_id}"
         try:
-            from run_agent_background import run_agent_background
-
-            # Start agent execution in background
-            run_agent_background.send(
-                agent_run_id=agent_run_id,
-                thread_id=thread_id,
-                instance_id="trigger_executor",
-                project_id=project_id,
-                model_name="anthropic/claude-sonnet-4-20250514",
-                enable_thinking=False,
-                reasoning_effort="low",
-                stream=False,
-                enable_context_manager=True,
-                agent_config=agent_config,
-                is_agent_builder=False,
-                target_agent_id=None,
-                request_id=None
-            )
-
-            logger.info(f"Started background agent execution for trigger (run_id: {agent_run_id})")
-            return agent_run_id
-
-        except ImportError:
-            # Fallback if background execution is not available
-            logger.warning("Background agent execution not available, marking as completed")
-            await client.table('agent_runs').update({
-                "status": "completed",
-                "completed_at": datetime.now(timezone.utc).isoformat(),
-                "error": "Background execution not available"
-            }).eq('id', agent_run_id).execute()
-            return agent_run_id
+            from services import redis
+            await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL)
+            logger.info(f"Registered trigger agent run in Redis ({instance_key})")
+        except Exception as e:
+            logger.warning(f"Failed to register trigger agent run in Redis ({instance_key}): {str(e)}")
+
+        logger.info(f"Created trigger agent run: {agent_run_id}")
+        return agent_run_id
 
 class TriggerResponseHandler:
     """Handles responses back to external services when agents complete."""

View File

@@ -1,35 +0,0 @@
-import dotenv
-dotenv.load_dotenv()
-
-from utils.logger import logger
-import run_agent_background
-from services import redis
-import asyncio
-from utils.retry import retry
-import uuid
-
-async def main():
-    await retry(lambda: redis.initialize_async())
-    key = uuid.uuid4().hex
-    run_agent_background.check_health.send(key)
-    timeout = 20  # seconds
-    elapsed = 0
-    while elapsed < timeout:
-        if await redis.get(key) == "healthy":
-            break
-        await asyncio.sleep(1)
-        elapsed += 1
-    if elapsed >= timeout:
-        logger.critical("Health check timed out")
-        exit(1)
-    else:
-        logger.critical("Health check passed")
-    await redis.delete(key)
-    await redis.close()
-    exit(0)
-
-if __name__ == "__main__":
-    asyncio.run(main())