mirror of https://github.com/kortix-ai/suna.git
255 lines
9.0 KiB
Python
255 lines
9.0 KiB
Python
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Optional, List, Dict, Any, Union

import uvicorn
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from agentpress.thread_manager import ThreadManager
from agentpress.api.ws import ws_manager
from agentpress.api.api_factory import (
    app as thread_task_app,
    register_thread_task_api,
    discover_tasks,
    thread_manager as task_thread_manager
)
|
||
# from agentpress.api_factory import app as api_app, discover_tasks
|
||
|
||
# Logging: a logger named after this module, with the root logger
# configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
||
|
||
# Global managers
|
||
thread_manager: Optional[ThreadManager] = None
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for the FastAPI application.

    Startup:
      * creates the module-level ``ThreadManager``;
      * hands that same instance to the task API module (``api_factory``);
      * awaits the manager's database initialization task, if one is pending;
      * runs task discovery.

    Shutdown: nothing to clean up yet.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    # Imported as a module object so its attribute can be rebound below.
    from agentpress.api import api_factory

    global thread_manager, task_thread_manager
    thread_manager = ThreadManager()

    # Share thread_manager with the task API.
    # The module-level `from ... import thread_manager as task_thread_manager`
    # only created an alias in THIS module. Rebinding that alias never reached
    # api_factory, so the module attribute itself must be set.
    # NOTE(review): assumes api_factory's handlers read their module-global
    # `thread_manager` at call time; confirm in api_factory.
    task_thread_manager = thread_manager
    api_factory.thread_manager = thread_manager

    # Wait for DB initialization (if the manager started one) before serving.
    db = thread_manager.db
    if db._initialization_task:
        await db._initialization_task

    # Run task discovery during startup.
    discover_tasks()

    yield

    # Shutdown: add any cleanup code here if needed.
|
||
|
||
# Application instance; startup/shutdown is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan, title="AgentPress API")

# CORS: accept requests from any origin, with any method and any header,
# and allow credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
||
|
||
# # Import and mount the API Factory app
|
||
# try:
|
||
# # Run task discovery
|
||
# # discover_tasks()
|
||
# logger.info("Task discovery completed")
|
||
|
||
# # Mount the API Factory app at /tasks instead of root
|
||
# app.mount("/tasks", api_app)
|
||
# logger.info("Mounted API Factory app at /tasks")
|
||
# except Exception as e:
|
||
# logger.error(f"Error setting up API Factory: {e}")
|
||
# raise
|
||
|
||
# Pydantic models for request/response validation

class MessageCreate(BaseModel):
    """Request body for ``POST /threads/{thread_id}/messages``.

    Each field is passed through to ``thread_manager.add_message``.
    """

    # Message payload: either a plain string or a dict.
    message_data: Union[str, Dict[str, Any]]
    # Optional list of image entries (dicts); None when omitted.
    images: Union[List[Dict[str, Any]], None] = None
    # Whether the message belongs in the LLM message history (default: yes).
    include_in_llm_message_history: bool = True
    # Optional message type label; None when omitted.
    message_type: Union[str, None] = None
|
||
|
||
# REST API Endpoints

@app.post("/threads", response_model=dict, status_code=201)
async def create_thread():
    """Create a thread and return its identifier as ``{"thread_id": ...}``."""
    new_thread_id = await thread_manager.create_thread()
    body = {"thread_id": new_thread_id}
    return body
|
||
|
||
@app.post("/threads/{thread_id}/messages", response_model=dict, status_code=201)
async def create_message(thread_id: str, message: MessageCreate, background_tasks: BackgroundTasks):
    """Create a new message in a thread and broadcast it to WebSocket clients.

    Args:
        thread_id: Identifier of the target thread (path parameter).
        message: Validated request body; its fields are passed to
            ``thread_manager.add_message``.
        background_tasks: FastAPI background-task queue. The WebSocket
            broadcast is scheduled on it rather than awaited inline.

    Returns:
        ``{"status": "success"}`` once the message has been stored.

    Raises:
        HTTPException: 404 if the thread does not exist; 400 if
            ``add_message`` raises (the error text becomes the detail).
    """
    if not await thread_manager.thread_exists(thread_id):
        raise HTTPException(status_code=404, detail="Thread not found")

    # Only the storage call is guarded, so unrelated failures are not
    # misreported as client (400) errors.
    try:
        await thread_manager.add_message(
            thread_id=thread_id,
            message_data=message.message_data,
            images=message.images,
            include_in_llm_message_history=message.include_in_llm_message_history,
            message_type=message.message_type
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Broadcast to WebSocket connections subscribed to this thread.
    background_tasks.add_task(
        ws_manager.broadcast_to_thread,
        thread_id,
        {"type": "message_created", "message": message.dict()}
    )
    return {"status": "success"}
|
||
|
||
# TODO: Broken for an unknown reason. It currently returns [] but should
# return the thread's messages in LLM message format.
@app.get("/threads/{thread_id}/llm_history_messages")
async def get_thread_llm_messages(
    thread_id: str,
    hide_tool_msgs: bool = False,
    only_latest_assistant: bool = False,
):
    """Return a thread's LLM-history messages as ``{"messages": ...}``.

    Args:
        thread_id: Identifier of the thread (path parameter).
        hide_tool_msgs: Forwarded to ``get_llm_history_messages``.
        only_latest_assistant: Forwarded to ``get_llm_history_messages``.

    Raises:
        HTTPException: 404 if the thread does not exist.
    """
    thread_found = await thread_manager.thread_exists(thread_id)
    if not thread_found:
        raise HTTPException(status_code=404, detail="Thread not found")

    history = await thread_manager.get_llm_history_messages(
        thread_id=thread_id,
        hide_tool_msgs=hide_tool_msgs,
        only_latest_assistant=only_latest_assistant,
    )
    return {"messages": history}
|
||
|
||
@app.get("/threads/{thread_id}/messages")
async def get_thread_messages(
    thread_id: str,
    # A list-typed parameter must be declared with Query(); otherwise FastAPI
    # treats it as a request body and ignores ?message_types=a&message_types=b.
    message_types: Optional[List[str]] = Query(None),
    limit: Optional[int] = 50,
    offset: Optional[int] = 0,
    before_timestamp: Optional[str] = None,
    after_timestamp: Optional[str] = None,
    include_in_llm_message_history: Optional[bool] = None,
    order: str = "asc"
):
    """
    Get messages from a thread with comprehensive filtering options.

    Args:
        thread_id: Thread identifier
        message_types: Optional list of message types to filter by
            (repeat the query parameter to pass several values)
        limit: Maximum number of messages to return (default: 50)
        offset: Number of messages to skip for pagination
        before_timestamp: Optional filter for messages before timestamp
        after_timestamp: Optional filter for messages after timestamp
        include_in_llm_message_history: Optional filter for LLM history inclusion
        order: Sort order - "asc" or "desc"

    Returns:
        ``{"messages": ...}`` holding whatever ``thread_manager.get_messages``
        returned.

    Raises:
        HTTPException: 404 if the thread does not exist; 400 if
            ``get_messages`` raises (the error text becomes the detail).
    """
    if not await thread_manager.thread_exists(thread_id):
        raise HTTPException(status_code=404, detail="Thread not found")

    try:
        messages = await thread_manager.get_messages(
            thread_id=thread_id,
            message_types=message_types,
            limit=limit,
            offset=offset,
            before_timestamp=before_timestamp,
            after_timestamp=after_timestamp,
            include_in_llm_message_history=include_in_llm_message_history,
            order=order
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"messages": messages}
|
||
|
||
# TODO: Only send polling updates (at an even higher frequency than once per second) while there are active tasks for the thread: start and stop the polling based on whether the thread has an active task. Implement this in api_factory as well, so it broadcasts task state and triggers/disables the polling.
|
||
|
||
# WebSocket Endpoint

@app.websocket("/threads/{thread_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    thread_id: str,
    # List-typed query parameters must be declared with Query(); otherwise
    # FastAPI never reads them from the query string.
    message_types: Optional[List[str]] = Query(None),
    limit: Optional[int] = 50,
    offset: Optional[int] = 0,
    before_timestamp: Optional[str] = None,
    after_timestamp: Optional[str] = None,
    include_in_llm_message_history: Optional[bool] = None,
    order: str = "desc"
):
    """
    WebSocket endpoint for real-time thread updates with filtering and pagination.

    Closes with code 4004 if the thread does not exist. Otherwise it registers
    the socket with ``ws_manager`` and, once per second, sends
    ``{"type": "messages", "data": <get_messages result>}``. On a processing
    error it sends ``{"type": "error", "data": <message>}`` and stops.
    The socket is always unregistered from ``ws_manager`` when the loop ends.

    Query Parameters:
        message_types: Optional list of message types to filter by
        limit: Maximum number of messages to return (default: 50)
        offset: Number of messages to skip (for pagination)
        before_timestamp: Optional timestamp to filter messages before
        after_timestamp: Optional timestamp to filter messages after
        include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion
        order: Sort order - "asc" or "desc" (default: desc)
    """
    registered = False
    try:
        if not await thread_manager.thread_exists(thread_id):
            await websocket.close(code=4004, reason="Thread not found")
            return

        await ws_manager.connect(websocket, thread_id)
        registered = True

        while True:
            try:
                # Get messages with all filters
                result = await thread_manager.get_messages(
                    thread_id=thread_id,
                    message_types=message_types,
                    limit=limit,
                    offset=offset,
                    before_timestamp=before_timestamp,
                    after_timestamp=after_timestamp,
                    include_in_llm_message_history=include_in_llm_message_history,
                    order=order
                )

                # Send messages and pagination info
                await websocket.send_json({
                    "type": "messages",
                    "data": result
                })

                # Poll every second
                await asyncio.sleep(1)

            except WebSocketDisconnect:
                break
            except Exception as e:
                logger.exception("WebSocket error: %s", e)
                # Best effort: the client may already be gone. A failure here
                # must not skip the unregister in `finally`.
                try:
                    await websocket.send_json({
                        "type": "error",
                        "data": str(e)
                    })
                except Exception as send_err:
                    logger.debug("Could not deliver error frame: %s", send_err)
                break

    except Exception as e:
        logger.exception("WebSocket connection error: %s", e)
        try:
            await websocket.close(code=1011, reason=str(e))
        except Exception as close_err:
            # Closing an already-closed socket can fail; nothing more to do.
            logger.debug("Could not close WebSocket: %s", close_err)
    finally:
        # Single cleanup point: runs on normal exit, disconnect, errors and
        # task cancellation, so ws_manager never keeps a dead socket.
        if registered:
            ws_manager.disconnect(websocket, thread_id)
|
||
|
||
# Serve the task API app (imported from api_factory) under the /tasks prefix.
app.mount(path="/tasks", app=thread_task_app)
|
||
|
||
if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")