mirror of https://github.com/kortix-ai/suna.git
refactor(agent): remove worker_health script and migrate run_agent_background functionality to run_agent module
This commit is contained in:
parent
619fe78969
commit
3bb7219bef
|
@ -5,14 +5,11 @@ import json
|
|||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from typing import Optional, List, Dict, Any, AsyncIterable
|
||||
import jwt
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
import tempfile
|
||||
import os
|
||||
from resumable_stream.runtime import create_resumable_stream_context, ResumableStreamContext
|
||||
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from services.supabase import DBConnection
|
||||
from services import redis
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
|
||||
|
@ -21,7 +18,7 @@ from services.billing import check_billing_status, can_use_model
|
|||
from utils.config import config
|
||||
from sandbox.sandbox import create_sandbox, delete_sandbox, get_or_start_sandbox
|
||||
from services.llm import make_llm_api_call
|
||||
from run_agent_background import run_agent_background, update_agent_run_status
|
||||
from agent.run_agent import run_agent_run_stream, update_agent_run_status, StreamBroadcaster
|
||||
from utils.constants import MODEL_NAME_ALIASES
|
||||
from flags.flags import is_enabled
|
||||
|
||||
|
@ -35,43 +32,6 @@ REDIS_RESPONSE_LIST_TTL = 3600 * 24
|
|||
|
||||
stream_context_global: Optional[ResumableStreamContext] = None
|
||||
|
||||
# Create stream broadcaster for multiple consumers
|
||||
class StreamBroadcaster:
    """Fan one async source out to any number of queue-based consumers.

    Each consumer gets its own ``asyncio.Queue`` from :meth:`add_consumer`.
    :meth:`start` copies every chunk from the source into every registered
    queue and then enqueues ``None`` as an end-of-stream sentinel.

    Because ``None`` is the sentinel, a source that itself yields ``None``
    will make the consumer helpers below stop early.
    """

    def __init__(self, source: AsyncIterable[Any]):
        self.source = source
        self.queues: List[asyncio.Queue] = []

    def add_consumer(self) -> asyncio.Queue:
        """Register a new consumer and return the queue it should read from.

        Only chunks broadcast after registration are delivered to the queue.
        """
        q: asyncio.Queue = asyncio.Queue()
        self.queues.append(q)
        return q

    async def start(self) -> None:
        """Pump the source into every consumer queue until it is exhausted.

        The sentinel is sent from a ``finally`` block so that consumers
        terminate even when the source raises or this coroutine is
        cancelled; without it they would wait on ``queue.get()`` forever.
        Any exception from the source still propagates to the caller.
        """
        try:
            async for chunk in self.source:
                for q in self.queues:
                    await q.put(chunk)
        finally:
            for q in self.queues:
                await q.put(None)  # Sentinel to close consumers

    @staticmethod
    async def queue_to_stream(queue: asyncio.Queue) -> AsyncIterable[Any]:
        """Expose a consumer queue as an async generator.

        Yields chunks in arrival order and stops at the ``None`` sentinel.
        """
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield chunk

    @staticmethod
    async def iterate_bg(queue: asyncio.Queue) -> None:
        """Drain a consumer queue, discarding every chunk, until the sentinel."""
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
|
||||
|
||||
|
||||
async def get_stream_context():
|
||||
global stream_context_global
|
||||
if stream_context_global:
|
||||
|
@ -438,8 +398,6 @@ async def start_agent(
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
|
||||
|
||||
request_id = structlog.contextvars.get_contextvars().get('request_id')
|
||||
|
||||
return {"agent_run_id": agent_run_id, "status": "running"}
|
||||
|
||||
@router.post("/agent-run/{agent_run_id}/stop")
|
||||
|
@ -718,7 +676,7 @@ async def stream_agent_run(
|
|||
return
|
||||
|
||||
# Create the stream
|
||||
stream = await stream_context.resumable_stream(agent_run_id, lambda: run_agent_background(
|
||||
stream = await stream_context.resumable_stream(agent_run_id, lambda: run_agent_run_stream(
|
||||
agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
|
||||
project_id=project_id,
|
||||
model_name=model_name,
|
||||
|
@ -1037,8 +995,6 @@ async def initiate_agent_with_files(
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
|
||||
|
||||
request_id = structlog.contextvars.get_contextvars().get('request_id')
|
||||
|
||||
return {"thread_id": thread_id, "agent_run_id": agent_run_id}
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -6,7 +6,7 @@ import sentry
|
|||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Optional, List, Dict, Any, AsyncIterable
|
||||
from services import redis
|
||||
from agent.run import run_agent
|
||||
from utils.logger import logger, structlog
|
||||
|
@ -22,6 +22,43 @@ db = DBConnection()
|
|||
instance_id = "single"
|
||||
|
||||
|
||||
# Create stream broadcaster for multiple consumers
|
||||
class StreamBroadcaster:
    """Broadcast chunks from a single async iterable to multiple consumers.

    Consumers register through :meth:`add_consumer` and each receives a
    private ``asyncio.Queue``. :meth:`start` pushes every chunk of the source
    onto all registered queues, followed by a ``None`` end-of-stream marker.

    ``None`` doubles as the end marker, so a source that yields ``None`` as
    data will cause the consumer helpers in this class to stop at that point.
    """

    def __init__(self, source: AsyncIterable[Any]):
        self.source = source
        self.queues: List[asyncio.Queue] = []

    def add_consumer(self) -> asyncio.Queue:
        """Create, register and return a queue for one additional consumer.

        A queue registered after broadcasting began only sees later chunks.
        """
        q: asyncio.Queue = asyncio.Queue()
        self.queues.append(q)
        return q

    async def start(self) -> None:
        """Forward every chunk of the source to all consumer queues.

        The ``None`` marker is enqueued in a ``finally`` clause: if the
        source raises, or this coroutine is cancelled mid-stream, consumers
        are still released instead of blocking on ``queue.get()`` forever.
        The original exception is re-raised to the caller unchanged.
        """
        try:
            async for chunk in self.source:
                for q in self.queues:
                    await q.put(chunk)
        finally:
            for q in self.queues:
                await q.put(None)  # Sentinel to close consumers

    @staticmethod
    async def queue_to_stream(queue: asyncio.Queue) -> AsyncIterable[Any]:
        """Adapt a consumer queue into an async generator.

        Chunks are yielded in the order they were enqueued; iteration ends
        when the ``None`` marker is read.
        """
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield chunk

    @staticmethod
    async def iterate_bg(queue: asyncio.Queue) -> None:
        """Consume and discard everything on a queue up to the ``None`` marker."""
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
|
||||
|
||||
|
||||
async def initialize():
|
||||
"""Initialize the agent API with resources from the main API."""
|
||||
global db, instance_id, _initialized
|
||||
|
@ -41,7 +78,7 @@ async def check_health(key: str):
|
|||
await redis.set(key, "healthy", ex=redis.REDIS_KEY_TTL)
|
||||
|
||||
|
||||
async def run_agent_background(
|
||||
async def run_agent_run_stream(
|
||||
agent_run_id: str,
|
||||
thread_id: str,
|
||||
instance_id: str,
|
||||
|
@ -107,22 +144,28 @@ async def run_agent_background(
|
|||
stop_signal_received = False
|
||||
|
||||
stop_redis_key = f"stop_signal:{agent_run_id}"
|
||||
|
||||
|
||||
async def check_for_stop_signal():
|
||||
nonlocal stop_signal_received
|
||||
try:
|
||||
while not stop_signal_received:
|
||||
message = await redis.client.get(stop_redis_key)
|
||||
if message == "STOP":
|
||||
logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
|
||||
logger.info(
|
||||
f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})"
|
||||
)
|
||||
stop_signal_received = True
|
||||
break
|
||||
await asyncio.sleep(0.5) # Short sleep to prevent tight loop
|
||||
await asyncio.sleep(0.5) # Short sleep to prevent tight loop
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
|
||||
logger.info(
|
||||
f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
|
||||
stop_signal_received = True # Stop the run if the checker fails
|
||||
logger.error(
|
||||
f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True
|
||||
)
|
||||
stop_signal_received = True # Stop the run if the checker fails
|
||||
|
||||
asyncio.create_task(check_for_stop_signal())
|
||||
|
||||
|
@ -150,7 +193,9 @@ async def run_agent_background(
|
|||
if stop_signal_received:
|
||||
logger.info(f"Agent run {agent_run_id} stopped by signal.")
|
||||
final_status = "stopped"
|
||||
trace.span(name="agent_run_stopped").end(status_message="agent_run_stopped", level="WARNING")
|
||||
trace.span(name="agent_run_stopped").end(
|
||||
status_message="agent_run_stopped", level="WARNING"
|
||||
)
|
||||
break
|
||||
|
||||
all_responses.append(response) # Keep for DB updates
|
||||
|
@ -158,7 +203,7 @@ async def run_agent_background(
|
|||
yield f"data: {json.dumps(response)}\n\n"
|
||||
else:
|
||||
yield f"data: {response}\n\n"
|
||||
|
||||
|
||||
# Check for agent-signaled completion or error
|
||||
if response.get("type") == "status":
|
||||
status_val = response.get("status")
|
|
@ -249,46 +249,32 @@ Please respond appropriately to this trigger event."""
|
|||
"agent_id": agent_config['agent_id'],
|
||||
"agent_version_id": agent_config.get('current_version_id'),
|
||||
"status": "running",
|
||||
"started_at": datetime.now(timezone.utc).isoformat()
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"metadata": {
|
||||
"model_name": "anthropic/claude-sonnet-4-20250514",
|
||||
"enable_thinking": False,
|
||||
"reasoning_effort": "low",
|
||||
"enable_context_manager": True,
|
||||
"trigger_execution": True,
|
||||
"trigger_variables": trigger_variables
|
||||
}
|
||||
}
|
||||
|
||||
agent_run = await client.table('agent_runs').insert(agent_run_data).execute()
|
||||
agent_run_id = agent_run.data[0]['id']
|
||||
|
||||
# Import and use the existing agent background execution
|
||||
# Register this run in Redis with TTL using trigger executor instance ID
|
||||
instance_id = "trigger_executor"
|
||||
instance_key = f"active_run:{instance_id}:{agent_run_id}"
|
||||
try:
|
||||
from run_agent_background import run_agent_background
|
||||
|
||||
# Start agent execution in background
|
||||
run_agent_background.send(
|
||||
agent_run_id=agent_run_id,
|
||||
thread_id=thread_id,
|
||||
instance_id="trigger_executor",
|
||||
project_id=project_id,
|
||||
model_name="anthropic/claude-sonnet-4-20250514",
|
||||
enable_thinking=False,
|
||||
reasoning_effort="low",
|
||||
stream=False,
|
||||
enable_context_manager=True,
|
||||
agent_config=agent_config,
|
||||
is_agent_builder=False,
|
||||
target_agent_id=None,
|
||||
request_id=None
|
||||
)
|
||||
|
||||
logger.info(f"Started background agent execution for trigger (run_id: {agent_run_id})")
|
||||
return agent_run_id
|
||||
|
||||
except ImportError:
|
||||
# Fallback if background execution is not available
|
||||
logger.warning("Background agent execution not available, marking as completed")
|
||||
await client.table('agent_runs').update({
|
||||
"status": "completed",
|
||||
"completed_at": datetime.now(timezone.utc).isoformat(),
|
||||
"error": "Background execution not available"
|
||||
}).eq('id', agent_run_id).execute()
|
||||
|
||||
return agent_run_id
|
||||
from services import redis
|
||||
await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL)
|
||||
logger.info(f"Registered trigger agent run in Redis ({instance_key})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register trigger agent run in Redis ({instance_key}): {str(e)}")
|
||||
|
||||
logger.info(f"Created trigger agent run: {agent_run_id}")
|
||||
return agent_run_id
|
||||
|
||||
class TriggerResponseHandler:
|
||||
"""Handles responses back to external services when agents complete."""
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
from utils.logger import logger
|
||||
import run_agent_background
|
||||
from services import redis
|
||||
import asyncio
|
||||
from utils.retry import retry
|
||||
import uuid
|
||||
|
||||
|
||||
async def main():
    """Run a one-shot worker health probe and exit with its result.

    Dispatches ``check_health`` with a fresh random key, then polls Redis
    once per second for that key to hold ``"healthy"``. Exits the process
    with status 0 when the value appears within the timeout and 1 otherwise.

    NOTE(review): ``check_health.send`` looks like a task-queue dispatch to a
    separate worker process; confirm against ``run_agent_background``.
    """
    # Function-scope import: sys.exit is used instead of the ``exit`` helper
    # injected by the ``site`` module, which is meant for interactive use and
    # is not guaranteed to exist (e.g. under ``python -S``).
    import sys

    await retry(lambda: redis.initialize_async())
    key = uuid.uuid4().hex
    run_agent_background.check_health.send(key)

    timeout = 20  # seconds
    healthy = False
    # One probe per second, at most ``timeout`` probes.
    for _ in range(timeout):
        if await redis.get(key) == "healthy":
            healthy = True
            break
        await asyncio.sleep(1)

    if not healthy:
        logger.critical("Health check timed out")
        # Release the Redis connection on the failure path as well.
        await redis.close()
        sys.exit(1)

    # Logged at critical level as in the original, so the message stays
    # visible under the same log-level configuration.
    logger.critical("Health check passed")
    await redis.delete(key)
    await redis.close()
    sys.exit(0)


if __name__ == "__main__":
    asyncio.run(main())
|
Loading…
Reference in New Issue