suna/agentpress/api.py

226 lines
10 KiB
Python

import asyncio
import logging
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Path, Query
from pydantic import BaseModel, Field

from agentpress.config import Settings
from agentpress.db import Database
from agentpress.thread_manager import ThreadManager
from agentpress.tool_registry import ToolRegistry
# The FastAPI application. title/description/version are the metadata shown
# in the generated OpenAPI / interactive API docs.
app = FastAPI(
    title="Thread Manager API",
    description="API for managing and running threads with LLM integration",
    version="1.0.0",
)
# Module-level singletons, created once at import time and shared by every
# request handler in this module. ThreadManager is handed the Database
# instance, so `db` must be constructed before `manager`.
# NOTE(review): any I/O or connection setup done by these constructors happens
# at import time; that behaviour is not visible from this file — confirm.
db = Database()
manager = ThreadManager(db)
# Registry read by the /tools/ endpoint to list tool schemas.
tool_registry = ToolRegistry()
class Message(BaseModel):
    """A single chat message to be appended to a thread."""

    # Who sent the message; free-form string, validated only as `str`.
    role: str = Field(
        default=...,
        description="The role of the message sender (e.g., 'user', 'assistant')",
    )
    # The message text itself.
    content: str = Field(
        default=...,
        description="The content of the message",
    )
class RunThreadRequest(BaseModel):
    """Request body for a single thread run; every field is passed through as a keyword argument."""

    # --- Required inputs -------------------------------------------------
    system_message: Dict[str, Any] = Field(
        default=...,
        description="The system message to be used for the thread run",
    )
    model_name: str = Field(
        default=...,
        description="The name of the LLM model to be used",
    )

    # --- Generation controls ---------------------------------------------
    temperature: float = Field(
        default=0.5,
        description="The sampling temperature for the LLM",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate",
    )

    # --- Tooling ---------------------------------------------------------
    tools: Optional[List[str]] = Field(
        default=None,
        description="The list of tools to be used in the thread run",
    )
    tool_choice: str = Field(
        default="auto",
        description="Controls which tool is called by the model",
    )

    # --- Per-run prompt tweaks -------------------------------------------
    additional_system_message: Optional[str] = Field(
        default=None,
        description="Additional system message to be appended to the existing system message. This is useful for modifying the behavior on a per-run basis without overriding other instructions.",
    )
    additional_message: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional message to be appended at the end of the conversation. This is useful for modifying the behavior on a per-run basis without overriding other instructions.",
    )

    # --- Execution switches ----------------------------------------------
    hide_tool_msgs: bool = Field(
        default=False,
        description="Whether to hide tool messages in the conversation history",
    )
    execute_tools_async: bool = Field(
        default=True,
        description="Whether to execute tools asynchronously",
    )
    use_tool_parser: bool = Field(
        default=False,
        description="Whether to use the tool parser for handling tool calls",
    )

    # --- Optional sampling / output format -------------------------------
    top_p: Optional[float] = Field(
        default=None,
        description="The nucleus sampling value",
    )
    response_format: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Specifies the format that the model must output",
    )
class RunThreadAgentRequest(BaseModel):
    """Request body for a multi-iteration agent run; every field is passed through as a keyword argument."""

    # --- Model configuration ---------------------------------------------
    system_message: Dict[str, Any] = Field(
        default=...,
        description="The system message to be used for the thread run",
    )
    model_name: str = Field(
        default=...,
        description="The name of the LLM model to be used",
    )
    temperature: float = Field(
        default=0.5,
        description="The sampling temperature for the LLM",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate",
    )
    tools: Optional[List[str]] = Field(
        default=None,
        description="The list of tools to be used in the thread run",
    )
    additional_system_message: Optional[str] = Field(
        default=None,
        description="Additional system message to be appended to the existing system message",
    )

    # --- Agent loop control ----------------------------------------------
    autonomous_iterations_amount: int = Field(
        default=3,
        description="The number of autonomous iterations for the agent to perform",
    )
    continue_instructions: str = Field(
        default=...,
        description="Instructions for continuing the conversation in subsequent iterations",
    )
@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread")
async def create_thread():
    """
    Create a thread and respond with its identifier under the `thread_id` key.
    """
    new_thread_id = await manager.create_thread()
    body = {"thread_id": new_thread_id}
    return body
@app.get("/threads/", response_model=List[Dict[str, Any]], summary="Get all threads")
async def get_threads():
    """
    List every thread, giving each one's `thread_id` and `created_at`.
    """
    all_threads = await manager.get_threads()
    return [
        dict(thread_id=t.thread_id, created_at=t.created_at)
        for t in all_threads
    ]
@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread")
async def add_message(thread_id: str, message: Message):
    """
    Append the supplied message to the thread identified by `thread_id`.
    """
    # The manager receives the message as a plain dict, not the pydantic model.
    payload = message.dict()
    await manager.add_message(thread_id, payload)
    return {"status": "success"}
@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread")
async def list_messages(thread_id: str):
    """
    Return every message stored in the thread identified by `thread_id`.
    """
    return await manager.list_messages(thread_id)
@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread")
async def run_thread(thread_id: str, request: RunThreadRequest):
    """
    Run the thread with the model, sampling and tool settings given in the request body.

    Returns the result produced by the thread manager for this run. Unexpected
    failures are reported as HTTP 500 with the error message as the detail.
    """
    try:
        # Every request field is forwarded as a keyword argument.
        return await manager.run_thread(thread_id, **request.dict())
    except HTTPException:
        # Deliberate HTTP errors raised downstream keep their status code
        # instead of being re-wrapped as a generic 500.
        raise
    except Exception as e:
        # The client only sees str(e); log the full traceback server-side so
        # the failure is diagnosable.
        logging.getLogger(__name__).exception("Thread run failed for thread %s", thread_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get(
    "/threads/{thread_id}/run/status/",
    response_model=Dict[str, Any],
    summary="Get latest thread run status",
    # The Python function was renamed because another handler in this module
    # is also called `get_thread_run_status`, which shadowed this one. Keep the
    # original route name so url_path_for() lookups and the generated OpenAPI
    # operationId are unchanged.
    name="get_thread_run_status",
)
async def get_latest_thread_run_status(thread_id: str):
    """
    Retrieve the status of the latest run for the specified thread.

    If the thread has no runs, a 200 response with a `status` message is
    returned rather than a 404.
    """
    latest_thread_run = await manager.get_latest_thread_run(thread_id)
    if latest_thread_run:
        return latest_thread_run
    return {"status": "No runs found for this thread."}
@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools")
async def get_tools():
    """
    Retrieve all available tools, keyed by name, with their description and full schema.

    Returns an empty object when no tools are registered.
    """
    tools = tool_registry.get_all_tools()
    if not tools:
        # Covers both an empty mapping and None (which previously crashed on
        # `.items()`); surface it in the logs instead of a stray stdout print.
        logging.getLogger(__name__).warning("No tools found in the registry")
        return {}
    return {
        name: {
            "name": name,
            # A tool schema may omit `description`; don't turn that into a 500.
            "description": tool_info["schema"]["function"].get("description", ""),
            "schema": tool_info["schema"],
        }
        for name, tool_info in tools.items()
    }
@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run")
async def get_run(
    thread_id: str = Path(..., description="The ID of the thread that was run"),
    run_id: str = Path(..., description="The ID of the run to retrieve"),
):
    """
    Fetch a single run of a thread; responds with 404 if no such run exists.
    """
    found = await manager.get_run(thread_id, run_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Run not found")
@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run")
async def cancel_run(
    thread_id: str = Path(..., description="The ID of the thread to which this run belongs"),
    run_id: str = Path(..., description="The ID of the run to cancel"),
):
    """
    Cancel an in-progress run and return it; responds with 404 if the run is not found.
    """
    cancelled = await manager.cancel_run(thread_id, run_id)
    if cancelled is not None:
        return cancelled
    raise HTTPException(status_code=404, detail="Run not found")
@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs")
async def list_runs(
    thread_id: str = Path(..., description="The ID of the thread the runs belong to"),
    limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned"),
):
    """
    List up to `limit` runs (1-100, default 20) belonging to the given thread.
    """
    return await manager.list_runs(thread_id, limit)
@app.post("/threads/{thread_id}/run_agent/", response_model=Dict[str, Any], summary="Run a thread agent")
async def run_thread_agent(thread_id: str, run_request: RunThreadAgentRequest):
    """
    Run the thread as an agent for the requested number of autonomous iterations.

    A `ValueError` from the thread manager is reported as HTTP 400. Any other
    failure is reported as HTTP 500. In both cases the error message is the detail.
    """
    try:
        # Every request field is forwarded as a keyword argument.
        return await manager.run_thread_agent(thread_id, **run_request.dict())
    except HTTPException:
        # Keep deliberate HTTP errors raised downstream intact.
        raise
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # The client only sees str(e); keep the traceback in server logs.
        logging.getLogger(__name__).exception("Agent run failed for thread %s", thread_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/threads/{thread_id}/agent_runs", response_model=List[Dict[str, Any]], summary="List agent runs")
async def list_agent_runs(
    thread_id: str = Path(..., description="The ID of the thread the agent runs belong to"),
    limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned"),
):
    """
    List up to `limit` agent runs (1-100, default 20) recorded for the given thread.
    """
    return await manager.list_agent_runs(thread_id, limit)
@app.post("/threads/{thread_id}/runs/{run_id}/stop", response_model=Dict[str, Any], summary="Stop a thread run")
async def stop_thread_run(
    thread_id: str = Path(..., description="The ID of the thread"),
    run_id: str = Path(..., description="The ID of the run to stop"),
):
    """
    Stop an in-progress thread run and return it.

    Responds with 404 when the run does not exist or has already finished or been stopped.
    """
    stopped = await manager.stop_thread_run(thread_id, run_id)
    if stopped is not None:
        return stopped
    raise HTTPException(status_code=404, detail="Run not found or already completed/stopped")
@app.post("/threads/{thread_id}/agent_runs/{run_id}/stop", response_model=Dict[str, Any], summary="Stop an agent run")
async def stop_agent_run(
    thread_id: str = Path(..., description="The ID of the thread"),
    run_id: str = Path(..., description="The ID of the agent run to stop"),
):
    """
    Stop an in-progress agent run, together with its associated thread runs, and return it.

    Responds with 404 when the agent run does not exist or has already finished or been stopped.
    """
    stopped = await manager.stop_agent_run(thread_id, run_id)
    if stopped is not None:
        return stopped
    raise HTTPException(status_code=404, detail="Agent run not found or already completed/stopped")
@app.get("/threads/{thread_id}/runs/{run_id}/status", response_model=Dict[str, Any], summary="Get thread run status")
async def get_thread_run_status(
    thread_id: str = Path(..., description="The ID of the thread"),
    run_id: str = Path(..., description="The ID of the run"),
):
    """
    Return the status and details of one specific thread run; responds with 404 if it is unknown.
    """
    status = await manager.get_thread_run_status(thread_id, run_id)
    if status is not None:
        return status
    raise HTTPException(status_code=404, detail="Run not found")
@app.get("/threads/{thread_id}/agent_runs/{run_id}/status", response_model=Dict[str, Any], summary="Get agent run status")
async def get_agent_run_status(
    thread_id: str = Path(..., description="The ID of the thread"),
    run_id: str = Path(..., description="The ID of the agent run"),
):
    """
    Return the status and details of one specific agent run; responds with 404 if it is unknown.
    """
    status = await manager.get_agent_run_status(thread_id, run_id)
    if status is not None:
        return status
    raise HTTPException(status_code=404, detail="Agent run not found")
# Entry point used when this module is executed directly rather than imported.
if __name__ == "__main__":
    import uvicorn

    # Serves on port 8000 and listens on all network interfaces.
    # NOTE(review): restrict `host` if this is ever exposed outside a trusted network.
    uvicorn.run(app, port=8000, host="0.0.0.0")