suna/agentpress/api/api.py

255 lines
9.0 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Union

import uvicorn
from fastapi import (
    BackgroundTasks,
    FastAPI,
    HTTPException,
    Query,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from agentpress.api.api_factory import (
    app as thread_task_app,
    discover_tasks,
    register_thread_task_api,
    thread_manager as task_thread_manager,
)
from agentpress.api.ws import ws_manager
from agentpress.thread_manager import ThreadManager
# Legacy import path, superseded by the agentpress.api.api_factory import above:
# from agentpress.api_factory import app as api_app, discover_tasks
# Configure root logging at INFO and create this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Process-wide ThreadManager used by every endpoint in this module.
# It stays None until the lifespan handler assigns it at application startup.
thread_manager: Optional[ThreadManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for the FastAPI application.

    Startup (before ``yield``):
      1. Create the module-level ``ThreadManager``.
      2. Publish that same instance on the task API module so both apps use
         one manager.
      3. Await the database's pending initialization task, if there is one.
      4. Run task discovery.

    Shutdown (after ``yield``): nothing to clean up at present.

    Args:
        app: The FastAPI application being started (unused here).
    """
    global thread_manager, task_thread_manager

    # Imported locally so the module object itself is available for patching.
    import agentpress.api.api_factory as task_api

    # Startup
    thread_manager = ThreadManager()

    # Share thread_manager with the task API. Rebinding the name imported into
    # *this* module (task_thread_manager) has no effect on api_factory's own
    # global, so the attribute is set on the module object directly.
    # NOTE(review): presumably api_factory reads its module-level
    # `thread_manager` at call time -- confirm against api_factory.
    task_api.thread_manager = thread_manager
    # Kept for backward compatibility with code that reads this module's alias.
    task_thread_manager = thread_manager

    # Wait for DB initialization
    init_task = thread_manager.db._initialization_task
    if init_task:
        await init_task

    # Run task discovery during startup
    discover_tasks()

    yield

    # Shutdown
    # Add any cleanup code here if needed
# Application instance; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(title="AgentPress API", lifespan=lifespan)

# CORS: accept any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with credentials is very permissive;
# confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# # Import and mount the API Factory app
# try:
# # Run task discovery
# # discover_tasks()
# logger.info("Task discovery completed")
# # Mount the API Factory app at /tasks instead of root
# app.mount("/tasks", api_app)
# logger.info("Mounted API Factory app at /tasks")
# except Exception as e:
# logger.error(f"Error setting up API Factory: {e}")
# raise
# Pydantic models for request/response validation
class MessageCreate(BaseModel):
    """Request body accepted by the message-creation endpoint.

    Every field is forwarded unchanged to ``thread_manager.add_message``.
    """

    # Message payload: either plain text or a JSON object.
    message_data: Union[str, Dict[str, Any]]
    # Optional list of image descriptors (each a JSON object).
    images: Optional[List[Dict[str, Any]]] = None
    # Flag passed through as `include_in_llm_message_history`; defaults to True.
    include_in_llm_message_history: bool = True
    # Optional message type label.
    message_type: Optional[str] = None
# REST API Endpoints
@app.post("/threads", response_model=dict, status_code=201)
async def create_thread():
    """Create a new thread and return its identifier.

    Returns:
        A dict of the form ``{"thread_id": <id>}``.
    """
    new_thread_id = await thread_manager.create_thread()
    return {"thread_id": new_thread_id}
@app.post("/threads/{thread_id}/messages", response_model=dict, status_code=201)
async def create_message(thread_id: str, message: MessageCreate, background_tasks: BackgroundTasks):
    """Store a new message in a thread and notify WebSocket subscribers.

    Args:
        thread_id: Identifier of the thread receiving the message.
        message: Validated request body.
        background_tasks: FastAPI background-task queue used for the broadcast.

    Returns:
        ``{"status": "success"}`` once the message is stored.

    Raises:
        HTTPException: 404 if the thread does not exist; 400 if storing the
            message or scheduling the broadcast raises.
    """
    thread_found = await thread_manager.thread_exists(thread_id)
    if not thread_found:
        raise HTTPException(status_code=404, detail="Thread not found")

    try:
        await thread_manager.add_message(
            thread_id=thread_id,
            message_data=message.message_data,
            images=message.images,
            include_in_llm_message_history=message.include_in_llm_message_history,
            message_type=message.message_type,
        )

        # The broadcast is queued as a background task, so it runs after the
        # response has been sent.
        event = {"type": "message_created", "message": message.dict()}
        background_tasks.add_task(ws_manager.broadcast_to_thread, thread_id, event)

        return {"status": "success"}
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
# TODO: Broken for some reason -- returns [] but should return the messages
# in LLM message style.
@app.get("/threads/{thread_id}/llm_history_messages")
async def get_thread_llm_messages(
    thread_id: str,
    hide_tool_msgs: bool = False,
    only_latest_assistant: bool = False,
):
    """Return a thread's LLM-history messages.

    Args:
        thread_id: Thread identifier.
        hide_tool_msgs: Forwarded to ``get_llm_history_messages``.
        only_latest_assistant: Forwarded to ``get_llm_history_messages``.

    Returns:
        ``{"messages": <result of get_llm_history_messages>}``.

    Raises:
        HTTPException: 404 if the thread does not exist.
    """
    thread_found = await thread_manager.thread_exists(thread_id)
    if not thread_found:
        raise HTTPException(status_code=404, detail="Thread not found")

    history = await thread_manager.get_llm_history_messages(
        thread_id=thread_id,
        hide_tool_msgs=hide_tool_msgs,
        only_latest_assistant=only_latest_assistant,
    )
    return {"messages": history}
@app.get("/threads/{thread_id}/messages")
async def get_thread_messages(
    thread_id: str,
    # List-typed parameters must be declared with Query(); otherwise FastAPI
    # treats them as a request-body field, which a GET client cannot supply
    # as `?message_types=a&message_types=b`.
    message_types: Optional[List[str]] = Query(None),
    limit: Optional[int] = 50,
    offset: Optional[int] = 0,
    before_timestamp: Optional[str] = None,
    after_timestamp: Optional[str] = None,
    include_in_llm_message_history: Optional[bool] = None,
    order: str = "asc"
):
    """
    Get messages from a thread with comprehensive filtering options.

    All filters are query parameters and are forwarded unchanged to
    ``thread_manager.get_messages``.

    Args:
        thread_id: Thread identifier
        message_types: Optional list of message types to filter by
            (repeat the query parameter to pass several values)
        limit: Maximum number of messages to return (default: 50)
        offset: Number of messages to skip for pagination
        before_timestamp: Optional filter for messages before timestamp
        after_timestamp: Optional filter for messages after timestamp
        include_in_llm_message_history: Optional filter for LLM history inclusion
        order: Sort order - "asc" or "desc"

    Returns:
        ``{"messages": <result of get_messages>}``.

    Raises:
        HTTPException: 404 if the thread does not exist; 400 if the
            underlying query raises.
    """
    if not await thread_manager.thread_exists(thread_id):
        raise HTTPException(status_code=404, detail="Thread not found")

    try:
        messages = await thread_manager.get_messages(
            thread_id=thread_id,
            message_types=message_types,
            limit=limit,
            offset=offset,
            before_timestamp=before_timestamp,
            after_timestamp=after_timestamp,
            include_in_llm_message_history=include_in_llm_message_history,
            order=order,
        )
        return {"messages": messages}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
# TODO: Only send polling updates (at an even higher frequency than 1 per
# second) if there are any active tasks for that thread. Start and stop the
# polling based on whether there is an active task for the thread. Implement
# this in API_FACTORY as well, to broadcast it and trigger/disable the polling.
# WebSocket Endpoint
@app.websocket("/threads/{thread_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    thread_id: str,
    # List-typed parameters need Query() to be read from the query string.
    message_types: Optional[List[str]] = Query(None),
    limit: Optional[int] = 50,
    offset: Optional[int] = 0,
    before_timestamp: Optional[str] = None,
    after_timestamp: Optional[str] = None,
    include_in_llm_message_history: Optional[bool] = None,
    order: str = "desc"
):
    """
    WebSocket endpoint for real-time thread updates with filtering and pagination.

    After registering the socket with ``ws_manager``, the handler re-runs the
    same filtered query once per second and sends the result as a
    ``{"type": "messages", "data": ...}`` frame until the client disconnects
    or an error occurs.

    Query Parameters:
        message_types: Optional list of message types to filter by
        limit: Maximum number of messages to return (default: 50)
        offset: Number of messages to skip (for pagination)
        before_timestamp: Optional timestamp to filter messages before
        after_timestamp: Optional timestamp to filter messages after
        include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion
        order: Sort order - "asc" or "desc" (default: desc)
    """
    try:
        if not await thread_manager.thread_exists(thread_id):
            await websocket.close(code=4004, reason="Thread not found")
            return

        await ws_manager.connect(websocket, thread_id)
        try:
            while True:
                # Get messages with all filters
                result = await thread_manager.get_messages(
                    thread_id=thread_id,
                    message_types=message_types,
                    limit=limit,
                    offset=offset,
                    before_timestamp=before_timestamp,
                    after_timestamp=after_timestamp,
                    include_in_llm_message_history=include_in_llm_message_history,
                    order=order,
                )

                # Send messages and pagination info
                await websocket.send_json({
                    "type": "messages",
                    "data": result
                })

                # Poll every second
                await asyncio.sleep(1)
        except WebSocketDisconnect:
            # Client went away; nothing to report back.
            pass
        except Exception as e:
            logger.error(f"WebSocket error: {e}")
            # May itself raise if the socket is already gone; the `finally`
            # below still unregisters the connection in that case.
            await websocket.send_json({
                "type": "error",
                "data": str(e)
            })
        finally:
            # Always unregister once connected -- including when the error
            # frame above fails to send or the handler task is cancelled.
            ws_manager.disconnect(websocket, thread_id)
    except Exception as e:
        logger.error(f"WebSocket connection error: {e}")
        try:
            await websocket.close(code=1011, reason=str(e))
        except Exception:
            # Best effort: the socket may already be closed.
            pass
# Serve the task API sub-application (thread_task_app) under the /tasks prefix.
app.mount("/tasks", thread_task_app)

if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)