mirror of https://github.com/kortix-ai/suna.git
167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from typing import List, Dict, Any, Optional
|
|
import asyncio
|
|
|
|
from core.db import Database
|
|
from core.thread_manager import ThreadManager
|
|
from core.agent_manager import AgentManager
|
|
from core.tools.tool_registry import ToolRegistry
|
|
from core.config import Settings
|
|
|
|
# FastAPI application object; the route decorators in this module attach to it.
app = FastAPI()

# Module-level service objects shared by the endpoint handlers.
# Both managers are constructed with the same Database instance.
db = Database()
manager = ThreadManager(db)
tool_registry = ToolRegistry()  # Initialize here
agent_manager = AgentManager(db)
|
|
|
|
# Pydantic models for request and response bodies
|
|
class Message(BaseModel):
    """Request body accepted by the add-message endpoint."""

    # Both fields are required strings; no further validation is applied here.
    role: str
    content: str
|
|
|
|
class RunThreadRequest(BaseModel):
    """Schema describing a manually configured thread run.

    NOTE(review): no endpoint visible in this module uses this model; the
    run endpoint accepts a raw dict instead -- confirm whether that is
    intentional.
    """

    # Required fields.
    system_message: Dict[str, Any]
    model_name: str

    # Optional fields with their defaults.
    temperature: float = 0.5
    max_tokens: Optional[int] = 500
    tools: Optional[List[str]] = None
    tool_choice: str = "required"
    additional_instructions: Optional[str] = None
    stream: bool = False
|
|
|
|
class Agent(BaseModel):
    """Request body used when creating or updating an agent."""

    # Required fields.
    name: str
    model: str
    system_prompt: str
    selected_tools: List[str]

    # Optional; defaults to 0.5 when omitted.
    temperature: float = 0.5
|
|
|
|
@app.post("/threads/")
async def create_thread():
    """Create a new thread and return its identifier."""
    new_thread_id = await manager.create_thread()
    return dict(thread_id=new_thread_id)
|
|
|
|
@app.get("/threads/")
async def get_threads():
    """Return one ``{"thread_id": ...}`` entry per thread from the manager."""
    summaries = []
    for item in await manager.get_threads():
        summaries.append({"thread_id": item.thread_id})
    return summaries
|
|
|
|
@app.post("/threads/{thread_id}/messages/")
async def add_message(thread_id: int, message: Message):
    """Add a message to the given thread and report success."""
    # The manager receives the message as a plain dict.
    payload = message.dict()
    await manager.add_message(thread_id, payload)
    return {"status": "success"}
|
|
|
|
@app.get("/threads/{thread_id}/messages/")
async def list_messages(thread_id: int):
    """Return the messages of the given thread as provided by the manager."""
    return await manager.list_messages(thread_id)
|
|
|
|
@app.post("/threads/{thread_id}/run/")
async def run_thread(thread_id: int, request: Dict[str, Any]):
    """Run a thread and return every chunk it produced.

    The request body selects one of two modes:

    * Agent-based: the body contains ``agent_id``.
    * Manual configuration: the body contains ``system_message`` and
      ``model_name`` (both required) and may contain ``temperature``,
      ``max_tokens`` and ``tools``. Omitted optional fields fall back to
      the defaults declared on ``RunThreadRequest`` (0.5, 500 and None).

    ``additional_instructions`` and ``stream`` are optional in both modes.

    Returns:
        ``{"response": [...]}`` with the chunks collected from the run.

    Raises:
        HTTPException: 501 when ``stream`` is truthy, 422 when a manual
            run is missing ``system_message`` or ``model_name``.
    """
    stream = request.get('stream', False)
    if stream:
        # Reject before starting a run whose output would never be consumed.
        raise HTTPException(status_code=501, detail="Streaming is not supported via this endpoint.")

    additional_instructions = request.get('additional_instructions')

    if 'agent_id' in request:
        # Agent-based run
        response_gen = manager.run_thread(
            thread_id=thread_id,
            agent_id=request['agent_id'],
            additional_instructions=additional_instructions,
            stream=stream,
        )
    else:
        # Manual configuration run
        missing = [key for key in ('system_message', 'model_name') if key not in request]
        if missing:
            # Report a client error instead of letting a KeyError become a 500.
            raise HTTPException(
                status_code=422,
                detail=f"Missing required field(s): {', '.join(missing)}",
            )
        response_gen = manager.run_thread(
            thread_id=thread_id,
            system_message=request['system_message'],
            model_name=request['model_name'],
            temperature=request.get('temperature', 0.5),
            max_tokens=request.get('max_tokens', 500),
            tools=request.get('tools'),
            additional_instructions=additional_instructions,
            stream=stream,
        )

    return {"response": [chunk async for chunk in response_gen]}
|
|
|
|
@app.get("/threads/{thread_id}/run/status/")
async def get_thread_run_status(thread_id: int):
    """Report the status and error message of the thread's latest run.

    Returns a placeholder status when the manager reports no run, and an
    error payload when an ``AttributeError`` is raised during the lookup.
    """
    try:
        run = await manager.get_latest_thread_run(thread_id)
        if not run:
            return {"status": "No runs found for this thread."}
        return {
            "status": run.status,
            "error_message": run.error_message,
        }
    except AttributeError:
        return {"status": "Error", "message": "Unable to retrieve thread run status."}
|
|
|
|
@app.get("/tools/")
async def get_tools():
    """Return name, description and schema for every registered tool.

    The result is keyed by tool name; the description is read from
    ``schema['function']['description']`` of each registry entry.
    """
    registered = tool_registry.get_all_tools()
    if not registered:
        print("No tools found in the registry")  # Debug print

    catalog = {}
    for tool_name, info in registered.items():
        schema = info['schema']
        catalog[tool_name] = {
            "name": tool_name,
            "description": schema['function']['description'],
            "schema": schema,
        }
    return catalog
|
|
|
|
@app.post("/agents/")
async def create_agent(agent: Agent):
    """Create an agent from the request body and return its identifier."""
    # The model's fields are forwarded to the manager as keyword arguments.
    fields = agent.dict()
    new_agent_id = await agent_manager.create_agent(**fields)
    return {"agent_id": new_agent_id}
|
|
|
|
@app.get("/agents/")
async def list_agents():
    """Return every agent from the agent manager as a list of plain dicts."""
    listing = []
    for record in await agent_manager.list_agents():
        listing.append({
            "id": record.id,
            "name": record.name,
            "model": record.model,
            "system_prompt": record.system_prompt,
            "selected_tools": record.selected_tools,
            "temperature": record.temperature,
            "created_at": record.created_at,
        })
    return listing
|
|
|
|
@app.get("/agents/{agent_id}")
async def get_agent(agent_id: int):
    """Return a single agent as a plain dict.

    Raises:
        HTTPException: 404 when the manager returns no agent for the id.
    """
    found = await agent_manager.get_agent(agent_id)
    if not found:
        raise HTTPException(status_code=404, detail="Agent not found")

    return {
        "id": found.id,
        "name": found.name,
        "model": found.model,
        "system_prompt": found.system_prompt,
        "selected_tools": found.selected_tools,
        "temperature": found.temperature,
        "created_at": found.created_at,
    }
|
|
|
|
@app.put("/agents/{agent_id}")
async def update_agent(agent_id: int, agent: Agent):
    """Update an agent with the fields from the request body.

    Raises:
        HTTPException: 404 when the manager reports a falsy result.
    """
    updated = await agent_manager.update_agent(agent_id, **agent.dict())
    if not updated:
        raise HTTPException(status_code=404, detail="Agent not found")
    return {"status": "success"}
|
|
|
|
@app.delete("/agents/{agent_id}")
async def delete_agent(agent_id: int):
    """Delete the agent with the given id.

    Raises:
        HTTPException: 404 when the manager reports a falsy result.
    """
    removed = await agent_manager.delete_agent(agent_id)
    if not removed:
        raise HTTPException(status_code=404, detail="Agent not found")
    return {"status": "success"}