suna/agentpress/api.py

183 lines
8.1 KiB
Python

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Path, Query
from pydantic import BaseModel, Field

from agentpress.config import Settings
from agentpress.db import Database
from agentpress.thread_manager import ThreadManager
from agentpress.tool_registry import ToolRegistry
# FastAPI application that the route decorators below register against.
app = FastAPI(
    title="Thread Manager API",
    description="API for managing and running threads with LLM integration",
    version="1.0.0",
)
# Module-level singletons shared by every request handler in this module.
# NOTE(review): these are constructed at import time, so any I/O their
# constructors perform happens when the module is imported — confirm intended.
db = Database()
# All thread/run operations in the handlers are delegated to this manager,
# which is constructed over `db`.
manager = ThreadManager(db)
# Source of the tool schemas returned by GET /tools/.
tool_registry = ToolRegistry()
class Message(BaseModel):
    """A single chat message posted to a thread.

    Both fields are required strings; no further validation is applied.
    """

    role: str = Field(
        default=...,
        description="The role of the message sender (e.g., 'user', 'assistant')",
    )
    content: str = Field(
        default=...,
        description="The content of the message",
    )
class RunThreadRequest(BaseModel):
    """Request body for ``POST /threads/{thread_id}/run/``.

    Only ``system_message`` and ``model_name`` are required; every other
    field falls back to the default declared on it below.
    """

    # Required inputs.
    system_message: Dict[str, Any] = Field(
        default=...,
        description="The system message to be used for the thread run",
    )
    model_name: str = Field(
        default=...,
        description="The name of the LLM model to be used",
    )

    # Generation settings.
    temperature: float = Field(
        default=0.5,
        description="The sampling temperature for the LLM",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate",
    )

    # Tool selection.
    tools: Optional[List[str]] = Field(
        default=None,
        description="The list of tools to be used in the thread run",
    )
    tool_choice: str = Field(
        default="auto",
        description="Controls which tool is called by the model",
    )

    # Conversation shaping.
    additional_system_message: Optional[str] = Field(
        default=None,
        description="Additional system message to be appended to the existing system message",
    )
    additional_message: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional message to be appended at the end of the conversation",
    )
    hide_tool_msgs: bool = Field(
        default=False,
        description="Whether to hide tool messages in the conversation history",
    )

    # Tool execution behaviour.
    execute_tools_async: bool = Field(
        default=True,
        description="Whether to execute tools asynchronously",
    )
    use_tool_parser: bool = Field(
        default=False,
        description="Whether to use the tool parser for handling tool calls",
    )

    # Further sampling / output controls.
    top_p: Optional[float] = Field(
        default=None,
        description="The nucleus sampling value",
    )
    response_format: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Specifies the format that the model must output",
    )

    # Autonomous iteration.
    autonomous_iterations_amount: Optional[int] = Field(
        default=None,
        description="The number of autonomous iterations to perform",
    )
    continue_instructions: Optional[str] = Field(
        default=None,
        description="Instructions for continuing the conversation in subsequent iterations",
    )

    # Optional lifecycle hook names, looked up by name when the run starts.
    initializer: Optional[str] = Field(
        default=None,
        description="Name of the initializer function",
    )
    pre_iteration: Optional[str] = Field(
        default=None,
        description="Name of the pre-iteration function",
    )
    after_iteration: Optional[str] = Field(
        default=None,
        description="Name of the after-iteration function",
    )
    finalizer: Optional[str] = Field(
        default=None,
        description="Name of the finalizer function",
    )
@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread")
async def create_thread():
    """
    Create a new thread and return its ID.

    Returns:
        A mapping with a single ``"thread_id"`` key holding the identifier
        produced by the thread manager.
    """
    new_id = await manager.create_thread()
    payload = {"thread_id": new_id}
    return payload
@app.get("/threads/", response_model=List[Dict[str, Any]], summary="Get all threads")
async def get_threads():
    """
    Retrieve a list of all threads.

    Each entry exposes only the thread's ``thread_id`` and ``created_at``.
    """
    all_threads = await manager.get_threads()
    return [
        {"thread_id": item.thread_id, "created_at": item.created_at}
        for item in all_threads
    ]
@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread")
async def add_message(thread_id: str, message: Message):
    """
    Append a message to the given thread.

    The validated ``Message`` is converted to a plain dict before being
    handed to the thread manager. Always answers ``{"status": "success"}``
    when the manager call does not raise.
    """
    message_payload = message.dict()
    await manager.add_message(thread_id, message_payload)
    return {"status": "success"}
@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread")
async def list_messages(thread_id: str):
    """
    Return every message stored for the given thread, exactly as the
    thread manager provides them.
    """
    return await manager.list_messages(thread_id)
@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread")
async def run_thread(thread_id: str, request: RunThreadRequest):
    """
    Create a run record for a thread and execute it.

    The run configuration is first passed to ``manager.create_thread_run``;
    ``system_message``, ``tools`` and ``response_format`` are serialized with
    ``json.dumps`` (so an omitted value is sent as the string ``"null"``).
    The thread is then run with the named lifecycle hooks resolved through
    ``get_function`` and every other request field forwarded as keywords.

    Raises:
        HTTPException: re-raised unchanged when raised during the run, so its
            status code is preserved; any other exception is reported as a
            500 whose detail is the exception text.
    """
    # Hook names are resolved to callables separately, so they must not
    # also be forwarded as plain strings.
    hook_fields = {"initializer", "pre_iteration", "after_iteration", "finalizer"}
    try:
        # Create a new ThreadRun object
        thread_run = await manager.create_thread_run(
            thread_id,
            model_name=request.model_name,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            tool_choice=request.tool_choice,
            execute_tools_async=request.execute_tools_async,
            system_message=json.dumps(request.system_message),
            tools=json.dumps(request.tools),
            response_format=json.dumps(request.response_format),
            autonomous_iterations_amount=request.autonomous_iterations_amount,
            continue_instructions=request.continue_instructions,
        )
        # Run the thread with the created ThreadRun object
        result = await manager.run_thread(
            thread_id=thread_id,
            thread_run=thread_run,
            initializer=get_function(request.initializer),
            pre_iteration=get_function(request.pre_iteration),
            after_iteration=get_function(request.after_iteration),
            finalizer=get_function(request.finalizer),
            **request.dict(exclude=hook_fields),
        )
        return result
    except HTTPException:
        # Deliberate HTTP errors (e.g. a 404) must keep their status code
        # instead of being masked as a generic 500 below.
        raise
    except Exception as e:
        # Chain the original error so the traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
def get_function(function_name: Optional[str]):
    """
    Resolve a lifecycle hook callable from its name.

    Placeholder: name-based lookup is not implemented yet, so the result is
    ``None`` for every input, whether or not a name is supplied.
    """
    # TODO: look ``function_name`` up in a predefined registry of allowed
    # callables once one exists.
    return None
@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools")
async def get_tools():
    """
    Retrieve a list of all available tools and their schemas.

    Returns:
        A mapping from tool name to ``{"name", "description", "schema"}``.
        ``description`` is read from ``schema["function"]["description"]``.
        It is an empty string when the schema omits it.
    """
    # Treat a missing (None) registry result like an empty one.
    tools = tool_registry.get_all_tools() or {}
    if not tools:
        # Not an error: an empty registry simply yields an empty mapping.
        logging.getLogger(__name__).warning("No tools found in the registry")
    return {
        name: {
            "name": name,
            # `description` is optional in function schemas; one tool lacking
            # it must not turn the whole listing into a 500.
            "description": tool_info["schema"]["function"].get("description", ""),
            "schema": tool_info["schema"],
        }
        for name, tool_info in tools.items()
    }
@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run")
async def get_run(
    thread_id: str = Path(..., description="The ID of the thread that was run"),
    run_id: str = Path(..., description="The ID of the run to retrieve")
):
    """
    Fetch a single run belonging to a thread.

    Raises:
        HTTPException: 404 when the manager returns ``None`` for the run.
    """
    found = await manager.get_run(thread_id, run_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Run not found")
@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run")
async def cancel_run(
    thread_id: str = Path(..., description="The ID of the thread to which this run belongs"),
    run_id: str = Path(..., description="The ID of the run to cancel")
):
    """
    Cancel a run that is ``in_progress`` and return the updated run.

    Raises:
        HTTPException: 404 when the manager returns ``None`` for the run.
    """
    cancelled = await manager.cancel_run(thread_id, run_id)
    if cancelled is not None:
        return cancelled
    raise HTTPException(status_code=404, detail="Run not found")
@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs")
async def list_runs(
    thread_id: str = Path(..., description="The ID of the thread the runs belong to"),
    limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned")
):
    """
    List the runs of a thread.

    ``limit`` (default 20, validated to 1-100) is passed straight through
    to the thread manager.
    """
    return await manager.list_runs(thread_id, limit)
@app.post("/threads/{thread_id}/runs/{run_id}/stop", response_model=Dict[str, Any], summary="Stop a thread run")
async def stop_thread_run(
    thread_id: str = Path(..., description="The ID of the thread"),
    run_id: str = Path(..., description="The ID of the run to stop")
):
    """
    Stop a thread run that is still in progress.

    Raises:
        HTTPException: 404 when the manager returns ``None`` (unknown run,
            or one that has already completed or been stopped).
    """
    stopped = await manager.stop_thread_run(thread_id, run_id)
    if stopped is not None:
        return stopped
    raise HTTPException(status_code=404, detail="Run not found or already completed/stopped")
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")