# Mirror of https://github.com/kortix-ai/suna.git
# Agent API module (Python, ~373 lines, 14 KiB)
from fastapi import APIRouter, HTTPException, Depends, Request
|
|
from fastapi.responses import StreamingResponse
|
|
import asyncio
|
|
import json
|
|
import traceback
|
|
from datetime import datetime, timezone
|
|
import uuid
|
|
from typing import Optional
|
|
import jwt
|
|
|
|
from agentpress.thread_manager import ThreadManager
|
|
from services.supabase import DBConnection
|
|
from services import redis
|
|
from agent.run import run_agent
|
|
from services.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access, verify_agent_run_access
|
|
|
|
# Initialize shared resources
router = APIRouter()
# Shared ThreadManager instance; populated by initialize() at startup.
thread_manager = None
# Shared DBConnection; populated by initialize() at startup.
db = None
# Short random ID identifying this API process; used to namespace the
# "active_run:{instance_id}:*" keys in Redis. Set in initialize().
instance_id = None
def initialize(
    _thread_manager: ThreadManager,
    _db: DBConnection
):
    """Wire up the module-level singletons used by every route handler.

    Called once by the main API at startup; stores the shared
    ThreadManager and DBConnection, and tags this process with a short
    random instance ID (used to namespace this instance's Redis keys).
    """
    global thread_manager, db, instance_id
    thread_manager = _thread_manager
    db = _db

    # Eight hex chars of a UUID4 are unique enough to tell instances apart.
    instance_id = uuid.uuid4().hex[:8]

# NOTE: Redis itself is initialized by the lifespan hook in api.py
async def cleanup():
    """Clean up resources and stop running agents on shutdown.

    Finds every run registered by THIS instance (keys are namespaced by
    instance_id, so other instances' runs are untouched), stops each one,
    then closes the Redis connection.
    """
    # Get Redis client
    redis_client = await redis.get_client()

    # NOTE(review): KEYS is O(N) over the whole keyspace; fine for small
    # deployments, consider SCAN if the keyspace grows.
    running_keys = await redis_client.keys(f"active_run:{instance_id}:*")

    for key in running_keys:
        # The client may return bytes (no decode_responses) or str — the
        # pub/sub handlers in this module accept both, so do the same here
        # instead of crashing on bytes.split(str).
        key_str = key.decode('utf-8') if isinstance(key, bytes) else key
        agent_run_id = key_str.split(":")[-1]
        await stop_agent_run(agent_run_id)

    # Close Redis connection
    await redis.close()
async def stop_agent_run(agent_run_id: str):
    """Mark a run as stopped in the DB and broadcast STOP over Redis.

    The background runner subscribed to the run's control channel picks
    up the STOP message and aborts its streaming loop.
    """
    client = await db.client
    redis_client = await redis.get_client()

    # Persist the terminal state first, so readers observe "stopped" even
    # if pub/sub delivery races with them.
    stopped_fields = {
        "status": "stopped",
        "completed_at": datetime.now(timezone.utc).isoformat(),
    }
    await client.table('agent_runs').update(stopped_fields).eq("id", agent_run_id).execute()

    # Tell the (possibly remote) runner to halt; the payload is a plain string.
    control_channel = f"agent_run:{agent_run_id}:control"
    await redis_client.publish(control_channel, "STOP")
async def restore_running_agent_runs():
    """Fail over any runs a previous process left in the 'running' state.

    Executed at startup: a run still marked running at that point can only
    mean the server died mid-run, so each one is marked failed with an
    explanatory error message.
    """
    client = await db.client
    stale_runs = await client.table('agent_runs').select('*').eq("status", "running").execute()

    for stale in stale_runs.data:
        failure_fields = {
            "status": "failed",
            "error": "Server restarted while agent was running",
            "completed_at": datetime.now(timezone.utc).isoformat(),
        }
        await client.table('agent_runs').update(failure_fields).eq("id", stale['id']).execute()
@router.post("/thread/{thread_id}/agent/start")
async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
    """Kick off a background agent run for a thread and return its ID."""
    client = await db.client
    redis_client = await redis.get_client()

    # Authorization: raises if the caller does not have access to the thread.
    await verify_thread_access(client, thread_id, user_id)

    # Persist a new run row, starting life in the "running" state.
    inserted = await client.table('agent_runs').insert({
        "thread_id": thread_id,
        "status": "running",
        "started_at": datetime.now(timezone.utc).isoformat(),
        "responses": "[]"  # JSON-encoded empty list
    }).execute()
    agent_run_id = inserted.data[0]['id']

    # Track the live run in Redis, namespaced by this instance, with a TTL
    # so keys from a crashed instance eventually expire on their own.
    active_key = f"active_run:{instance_id}:{agent_run_id}"
    await redis_client.set(active_key, "running", ex=redis.REDIS_KEY_TTL)

    # Fire-and-forget the actual agent execution.
    # NOTE(review): asyncio holds tasks only weakly; consider keeping a
    # strong reference to this task in a module-level set.
    task = asyncio.create_task(
        run_agent_background(agent_run_id, thread_id, instance_id)
    )

    # When the run finishes (for any reason), drop its Redis tracking key.
    task.add_done_callback(
        lambda _: asyncio.create_task(_cleanup_agent_run(agent_run_id))
    )

    return {"agent_run_id": agent_run_id, "status": "running"}
async def _cleanup_agent_run(agent_run_id: str):
    """Drop this instance's Redis tracking key for a finished run."""
    redis_client = await redis.get_client()
    tracking_key = f"active_run:{instance_id}:{agent_run_id}"
    await redis_client.delete(tracking_key)
@router.post("/agent-run/{agent_run_id}/stop")
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id)):
    """HTTP entry point for stopping a running agent."""
    client = await db.client

    # Authorization: raises if the caller cannot see this run.
    await verify_agent_run_access(client, agent_run_id, user_id)

    # Delegate the actual stop (DB status update + Redis STOP broadcast).
    await stop_agent_run(agent_run_id)

    return {"status": "stopped"}
@router.get("/agent-run/{agent_run_id}/stream")
async def stream_agent_run(
    agent_run_id: str,
    token: Optional[str] = None,
    request: Request = None
):
    """Stream the responses of an agent run from where they left off.

    Serves Server-Sent Events: first replays the responses already
    persisted in the DB, then relays live messages from the run's Redis
    pub/sub channel until END_STREAM arrives or the run leaves the
    'running' state.
    """
    client = await db.client
    redis_client = await redis.get_client()

    # Get user ID using the streaming auth function
    # (EventSource cannot set headers, so the token may arrive as a query param).
    user_id = await get_user_id_from_stream_auth(request, token)

    # Verify user has access to the agent run and get run data
    agent_run_data = await verify_agent_run_access(client, agent_run_id, user_id)

    # Stored responses are a JSON string; NULL/empty maps to [].
    responses = json.loads(agent_run_data['responses']) if agent_run_data['responses'] else []

    # Create a pubsub to listen for new responses
    # NOTE(review): subscription starts AFTER the stored responses were read
    # above, so a message published in that gap could be missed by this
    # stream — confirm whether the periodic DB status poll below is
    # considered sufficient coverage for that window.
    pubsub = redis_client.pubsub()
    await pubsub.subscribe(f"agent_run:{agent_run_id}:responses")

    # Define the streaming generator
    async def event_generator():
        try:
            # First send any existing responses
            for response in responses:
                yield f"data: {json.dumps(response)}\n\n"

            # Then stream new responses
            while True:
                message = await pubsub.get_message(timeout=0.1)  # Reduced timeout for faster response
                if message and message["type"] == "message":
                    data = message["data"]

                    # Check if this is the end marker
                    end_stream_marker = "END_STREAM"
                    if data == end_stream_marker or data == end_stream_marker.encode('utf-8'):
                        break

                    # Handle both string and bytes data
                    # (the Redis client may or may not decode responses)
                    if isinstance(data, bytes):
                        data_str = data.decode('utf-8')
                    else:
                        data_str = str(data)

                    # Don't add extra formatting to already JSON-formatted data
                    yield f"data: {data_str}\n\n"

                # Check if agent is still running
                # NOTE(review): this DB round-trip happens every loop
                # iteration (~10/sec) — verify that load is acceptable.
                current_run = await client.table('agent_runs').select('status').eq('id', agent_run_id).execute()
                if not current_run.data or current_run.data[0]['status'] != 'running':
                    # Send final status update
                    yield f"data: {json.dumps({'type': 'status', 'status': current_run.data[0]['status'] if current_run.data else 'unknown'})}\n\n"
                    break

                await asyncio.sleep(0.01)  # Minimal sleep to prevent CPU spinning

        except asyncio.CancelledError:
            # Client disconnected; fall through to unsubscribe.
            pass
        finally:
            await pubsub.unsubscribe()

    # Return a StreamingResponse with the correct headers for SSE
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable nginx proxy buffering
            "Content-Type": "text/event-stream",
            "Access-Control-Allow-Origin": "*"  # Add CORS header for EventSource
        }
    )
@router.get("/thread/{thread_id}/agent-runs")
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id)):
    """List every agent run recorded for a thread."""
    client = await db.client

    # Authorization: raises if the caller cannot access the thread.
    await verify_thread_access(client, thread_id, user_id)

    result = await client.table('agent_runs').select('*').eq("thread_id", thread_id).execute()
    return {"agent_runs": result.data}
@router.get("/agent-run/{agent_run_id}")
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id)):
    """Return a single agent run's status plus its accumulated responses."""
    client = await db.client

    # The access check doubles as the data fetch.
    run = await verify_agent_run_access(client, agent_run_id, user_id)

    # 'responses' is stored as a JSON string; NULL/empty maps to [].
    parsed_responses = json.loads(run['responses']) if run['responses'] else []

    # camelCase keys: this is the shape the frontend consumes.
    return {
        "id": run['id'],
        "threadId": run['thread_id'],
        "status": run['status'],
        "startedAt": run['started_at'],
        "completedAt": run['completed_at'],
        "responses": parsed_responses,
        "error": run['error'],
    }
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str):
    """Run the agent in the background and store responses.

    Drives run_agent() as a stream and, for each chunk:
      * publishes it immediately on the run's Redis responses channel
        (consumed by stream_agent_run), and
      * accumulates it into the agent_runs.responses JSON column,
        flushed at most every ~2 seconds.

    A sibling task watches the run's control channel for a STOP signal
    (published by stop_agent_run) and aborts the loop when one arrives.
    On every outcome the run row ends in a terminal status and an
    END_STREAM marker is published so streaming clients disconnect.
    """
    client = await db.client
    redis_client = await redis.get_client()

    # Create a buffer to store response chunks
    responses = []   # every chunk produced so far (persisted as a whole)
    batch = []       # chunks accumulated since the last DB flush
    last_db_update = datetime.now(timezone.utc)

    # Create a pubsub to listen for control messages
    pubsub = redis_client.pubsub()
    await pubsub.subscribe(f"agent_run:{agent_run_id}:control")

    # Start a background task to check for stop signals
    stop_signal_received = False

    async def check_for_stop_signal():
        # Polls the control channel until STOP arrives, then flips the
        # nonlocal flag that the main streaming loop checks between chunks.
        nonlocal stop_signal_received
        while True:
            message = await pubsub.get_message(timeout=0.1)  # Reduced timeout
            if message and message["type"] == "message":
                stop_signal = "STOP"
                # The client may deliver str or bytes depending on config.
                if message["data"] == stop_signal or message["data"] == stop_signal.encode('utf-8'):
                    stop_signal_received = True
                    break
            await asyncio.sleep(0.01)  # Minimal sleep
            # Defensive re-check; the flag is only set above (which breaks
            # immediately), so this branch should be unreachable in practice.
            if stop_signal_received:
                break

    # Start the stop signal checker
    stop_checker = asyncio.create_task(check_for_stop_signal())

    try:
        # Run the agent and collect responses
        agent_gen = run_agent(thread_id, stream=True,
                              thread_manager=thread_manager)

        async for response in agent_gen:
            # Check if stop signal received
            if stop_signal_received:
                break

            # Format the response properly
            formatted_response = None

            # Handle different types of responses: strings and unknown types
            # are wrapped as {"type": "content", ...}; dicts keep their own
            # "type" when present.
            if isinstance(response, str):
                formatted_response = {"type": "content", "content": response}
            elif isinstance(response, dict):
                if "type" in response:
                    formatted_response = response
                else:
                    formatted_response = {"type": "content", **response}
            else:
                formatted_response = {"type": "content", "content": str(response)}

            # Add response to batch and responses list
            responses.append(formatted_response)
            batch.append(formatted_response)

            # Immediately publish the response to Redis
            await redis_client.publish(
                f"agent_run:{agent_run_id}:responses",
                json.dumps(formatted_response)
            )

            # Update database less frequently to reduce overhead
            # NOTE: each flush rewrites the FULL responses list, so the
            # payload grows with run length.
            now = datetime.now(timezone.utc)
            if (now - last_db_update).total_seconds() >= 2.0 and batch:  # Increased interval
                await client.table('agent_runs').update({
                    "responses": json.dumps(responses)
                }).eq("id", agent_run_id).execute()

                batch = []
                last_db_update = now

            # No sleep needed here - let it run as fast as possible

        # Final update to database with all responses
        if batch:
            await client.table('agent_runs').update({
                "responses": json.dumps(responses)
            }).eq("id", agent_run_id).execute()

        # Signal all done if we weren't stopped
        # (stop_agent_run already set status to "stopped" in that case).
        if not stop_signal_received:
            await client.table('agent_runs').update({
                "status": "completed",
                "completed_at": datetime.now(timezone.utc).isoformat()
            }).eq("id", agent_run_id).execute()

        # Send END_STREAM signal
        end_stream_marker = "END_STREAM"
        await redis_client.publish(
            f"agent_run:{agent_run_id}:responses",
            end_stream_marker
        )

    except Exception as e:
        # Log the error and update the agent run
        error_message = str(e)
        traceback_str = traceback.format_exc()
        print(f"Error in agent run {agent_run_id}: {error_message}\n{traceback_str}")

        # Update the agent run with the error
        await client.table('agent_runs').update({
            "status": "failed",
            "error": f"{error_message}\n{traceback_str}",
            "completed_at": datetime.now(timezone.utc).isoformat()
        }).eq("id", agent_run_id).execute()

        # Send END_STREAM signal so streaming clients still terminate.
        end_stream_marker = "END_STREAM"
        await redis_client.publish(
            f"agent_run:{agent_run_id}:responses",
            end_stream_marker
        )
    finally:
        # Ensure we always clean up the pubsub and stop checker
        stop_checker.cancel()
        await pubsub.unsubscribe()

        # Make sure we mark the run as completed or failed if it was still running.
        # NOTE(review): stop_agent_run sets 'stopped' BEFORE publishing STOP,
        # so 'running' here together with stop_signal_received is an edge case;
        # it is recorded as "failed" — confirm that is the intended status.
        current_run = await client.table('agent_runs').select('status').eq("id", agent_run_id).execute()
        if current_run.data and current_run.data[0]['status'] == 'running':
            await client.table('agent_runs').update({
                "status": "failed" if stop_signal_received else "completed",
                "completed_at": datetime.now(timezone.utc).isoformat()
            }).eq("id", agent_run_id).execute()