mirror of https://github.com/kortix-ai/suna.git
wip
This commit is contained in:
parent
a933193851
commit
6c903fa761
|
@ -1,58 +0,0 @@
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from core.db import Agent
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
import json
|
|
||||||
|
|
||||||
class AgentManager:
|
|
||||||
def __init__(self, db):
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
async def create_agent(self, model: str, name: str, system_prompt: str, selected_tools: List[str], temperature: float = 0.5) -> int:
|
|
||||||
async with self.db.get_async_session() as session:
|
|
||||||
new_agent = Agent(
|
|
||||||
model=model,
|
|
||||||
name=name,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
selected_tools=selected_tools, # Store as a list directly
|
|
||||||
temperature=temperature,
|
|
||||||
created_at=datetime.now().isoformat()
|
|
||||||
)
|
|
||||||
session.add(new_agent)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(new_agent)
|
|
||||||
return new_agent.id
|
|
||||||
|
|
||||||
async def get_agent(self, agent_id: int) -> Optional[Agent]:
|
|
||||||
async with self.db.get_async_session() as session:
|
|
||||||
result = await session.execute(select(Agent).filter(Agent.id == agent_id))
|
|
||||||
agent = result.scalar_one_or_none()
|
|
||||||
return agent
|
|
||||||
|
|
||||||
async def update_agent(self, agent_id: int, **kwargs) -> bool:
|
|
||||||
async with self.db.get_async_session() as session:
|
|
||||||
result = await session.execute(select(Agent).filter(Agent.id == agent_id))
|
|
||||||
agent = result.scalar_one_or_none()
|
|
||||||
if agent:
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
setattr(agent, key, value)
|
|
||||||
await session.commit()
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def delete_agent(self, agent_id: int) -> bool:
|
|
||||||
async with self.db.get_async_session() as session:
|
|
||||||
result = await session.execute(select(Agent).filter(Agent.id == agent_id))
|
|
||||||
agent = result.scalar_one_or_none()
|
|
||||||
if agent:
|
|
||||||
await session.delete(agent)
|
|
||||||
await session.commit()
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def list_agents(self) -> List[Agent]:
|
|
||||||
async with self.db.get_async_session() as session:
|
|
||||||
result = await session.execute(select(Agent))
|
|
||||||
agents = result.scalars().all()
|
|
||||||
return agents
|
|
231
core/api.py
231
core/api.py
|
@ -1,109 +1,107 @@
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Query, Path
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from core.db import Database
|
from core.db import Database
|
||||||
from core.thread_manager import ThreadManager
|
from core.thread_manager import ThreadManager
|
||||||
from core.agent_manager import AgentManager
|
from core.tool_registry import ToolRegistry
|
||||||
from core.tools.tool_registry import ToolRegistry
|
|
||||||
from core.config import Settings
|
from core.config import Settings
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI(
|
||||||
|
title="Thread Manager API",
|
||||||
|
description="API for managing and running threads with LLM integration",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
|
|
||||||
db = Database()
|
db = Database()
|
||||||
manager = ThreadManager(db)
|
manager = ThreadManager(db)
|
||||||
tool_registry = ToolRegistry() # Initialize here
|
tool_registry = ToolRegistry()
|
||||||
agent_manager = AgentManager(db)
|
|
||||||
|
|
||||||
# Pydantic models for request and response bodies
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
role: str
|
role: str = Field(..., description="The role of the message sender (e.g., 'user', 'assistant')")
|
||||||
content: str
|
content: str = Field(..., description="The content of the message")
|
||||||
|
|
||||||
class RunThreadRequest(BaseModel):
|
class RunThreadRequest(BaseModel):
|
||||||
system_message: Dict[str, Any]
|
system_message: Dict[str, Any] = Field(..., description="The system message to be used for the thread run")
|
||||||
model_name: str
|
model_name: str = Field(..., description="The name of the LLM model to be used")
|
||||||
temperature: float = 0.5
|
temperature: float = Field(0.5, description="The sampling temperature for the LLM")
|
||||||
max_tokens: Optional[int] = 500
|
max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate")
|
||||||
tools: Optional[List[str]] = None
|
tools: Optional[List[str]] = Field(None, description="The list of tools to be used in the thread run")
|
||||||
tool_choice: str = "required"
|
tool_choice: str = Field("auto", description="Controls which tool is called by the model")
|
||||||
additional_instructions: Optional[str] = None
|
additional_system_message: Optional[str] = Field(None, description="Additional system message to be appended")
|
||||||
stream: bool = False
|
stream: bool = Field(False, description="Whether to stream the response")
|
||||||
|
top_p: Optional[float] = Field(None, description="The nucleus sampling value")
|
||||||
|
response_format: Optional[Dict[str, Any]] = Field(None, description="Specifies the format that the model must output")
|
||||||
|
|
||||||
class Agent(BaseModel):
|
@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread")
|
||||||
name: str
|
|
||||||
model: str
|
|
||||||
system_prompt: str
|
|
||||||
selected_tools: List[str]
|
|
||||||
temperature: float = 0.5
|
|
||||||
|
|
||||||
@app.post("/threads/")
|
|
||||||
async def create_thread():
|
async def create_thread():
|
||||||
|
"""
|
||||||
|
Create a new thread and return its ID.
|
||||||
|
"""
|
||||||
thread_id = await manager.create_thread()
|
thread_id = await manager.create_thread()
|
||||||
return {"thread_id": thread_id}
|
return {"thread_id": thread_id}
|
||||||
|
|
||||||
@app.get("/threads/")
|
@app.get("/threads/", response_model=List[Dict[str, str]], summary="Get all threads")
|
||||||
async def get_threads():
|
async def get_threads():
|
||||||
|
"""
|
||||||
|
Retrieve a list of all thread IDs.
|
||||||
|
"""
|
||||||
threads = await manager.get_threads()
|
threads = await manager.get_threads()
|
||||||
return [{"thread_id": thread.thread_id} for thread in threads]
|
return [{"thread_id": thread.thread_id} for thread in threads]
|
||||||
|
|
||||||
@app.post("/threads/{thread_id}/messages/")
|
@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread")
|
||||||
async def add_message(thread_id: int, message: Message):
|
async def add_message(thread_id: str, message: Message):
|
||||||
|
"""
|
||||||
|
Add a new message to the specified thread.
|
||||||
|
"""
|
||||||
await manager.add_message(thread_id, message.dict())
|
await manager.add_message(thread_id, message.dict())
|
||||||
return {"status": "success"}
|
return {"status": "success"}
|
||||||
|
|
||||||
@app.get("/threads/{thread_id}/messages/")
|
@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread")
|
||||||
async def list_messages(thread_id: int):
|
async def list_messages(thread_id: str):
|
||||||
|
"""
|
||||||
|
Retrieve all messages from the specified thread.
|
||||||
|
"""
|
||||||
messages = await manager.list_messages(thread_id)
|
messages = await manager.list_messages(thread_id)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@app.post("/threads/{thread_id}/run/")
|
@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread")
|
||||||
async def run_thread(thread_id: int, request: Dict[str, Any]):
|
async def run_thread(thread_id: str, request: RunThreadRequest):
|
||||||
if 'agent_id' in request:
|
"""
|
||||||
# Agent-based run
|
Run the specified thread with the given parameters.
|
||||||
response_gen = manager.run_thread(
|
"""
|
||||||
thread_id=thread_id,
|
response = await manager.run_thread(
|
||||||
agent_id=request['agent_id'],
|
thread_id=thread_id,
|
||||||
additional_instructions=request.get('additional_instructions'),
|
system_message=request.system_message,
|
||||||
stream=request.get('stream', False)
|
model_name=request.model_name,
|
||||||
)
|
temperature=request.temperature,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
tools=request.tools,
|
||||||
|
additional_system_message=request.additional_system_message,
|
||||||
|
top_p=request.top_p,
|
||||||
|
tool_choice=request.tool_choice,
|
||||||
|
response_format=request.response_format)
|
||||||
|
|
||||||
|
return {"status": "success", "response": response}
|
||||||
|
|
||||||
|
@app.get("/threads/{thread_id}/run/status/", response_model=Dict[str, Any], summary="Get thread run status")
|
||||||
|
async def get_thread_run_status(thread_id: str):
|
||||||
|
"""
|
||||||
|
Retrieve the status of the latest run for the specified thread.
|
||||||
|
"""
|
||||||
|
latest_thread_run = await manager.get_latest_thread_run(thread_id)
|
||||||
|
if latest_thread_run:
|
||||||
|
return latest_thread_run
|
||||||
else:
|
else:
|
||||||
# Manual configuration run
|
return {"status": "No runs found for this thread."}
|
||||||
response_gen = manager.run_thread(
|
|
||||||
thread_id=thread_id,
|
|
||||||
system_message=request['system_message'],
|
|
||||||
model_name=request['model_name'],
|
|
||||||
temperature=request['temperature'],
|
|
||||||
max_tokens=request['max_tokens'],
|
|
||||||
tools=request['tools'],
|
|
||||||
additional_instructions=request.get('additional_instructions'),
|
|
||||||
stream=request.get('stream', False)
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.get('stream', False):
|
@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools")
|
||||||
raise HTTPException(status_code=501, detail="Streaming is not supported via this endpoint.")
|
|
||||||
else:
|
|
||||||
response = []
|
|
||||||
async for chunk in response_gen:
|
|
||||||
response.append(chunk)
|
|
||||||
return {"response": response}
|
|
||||||
|
|
||||||
@app.get("/threads/{thread_id}/run/status/")
|
|
||||||
async def get_thread_run_status(thread_id: int):
|
|
||||||
try:
|
|
||||||
latest_thread_run = await manager.get_latest_thread_run(thread_id)
|
|
||||||
if latest_thread_run:
|
|
||||||
return {
|
|
||||||
"status": latest_thread_run.status,
|
|
||||||
"error_message": latest_thread_run.error_message
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {"status": "No runs found for this thread."}
|
|
||||||
except AttributeError:
|
|
||||||
return {"status": "Error", "message": "Unable to retrieve thread run status."}
|
|
||||||
|
|
||||||
@app.get("/tools/")
|
|
||||||
async def get_tools():
|
async def get_tools():
|
||||||
|
"""
|
||||||
|
Retrieve a list of all available tools and their schemas.
|
||||||
|
"""
|
||||||
tools = tool_registry.get_all_tools()
|
tools = tool_registry.get_all_tools()
|
||||||
if not tools:
|
if not tools:
|
||||||
print("No tools found in the registry") # Debug print
|
print("No tools found in the registry") # Debug print
|
||||||
|
@ -116,52 +114,45 @@ async def get_tools():
|
||||||
for name, tool_info in tools.items()
|
for name, tool_info in tools.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.post("/agents/")
|
@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run")
|
||||||
async def create_agent(agent: Agent):
|
async def get_run(
|
||||||
agent_id = await agent_manager.create_agent(**agent.dict())
|
thread_id: str = Path(..., description="The ID of the thread that was run"),
|
||||||
return {"agent_id": agent_id}
|
run_id: str = Path(..., description="The ID of the run to retrieve")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Retrieve the run object matching the specified ID.
|
||||||
|
"""
|
||||||
|
run = await manager.get_run(thread_id, run_id)
|
||||||
|
if run is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Run not found")
|
||||||
|
return run
|
||||||
|
|
||||||
@app.get("/agents/")
|
@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run")
|
||||||
async def list_agents():
|
async def cancel_run(
|
||||||
agents = await agent_manager.list_agents()
|
thread_id: str = Path(..., description="The ID of the thread to which this run belongs"),
|
||||||
return [
|
run_id: str = Path(..., description="The ID of the run to cancel")
|
||||||
{
|
):
|
||||||
"id": agent.id,
|
"""
|
||||||
"name": agent.name,
|
Cancels a run that is in_progress.
|
||||||
"model": agent.model,
|
"""
|
||||||
"system_prompt": agent.system_prompt,
|
run = await manager.cancel_run(thread_id, run_id)
|
||||||
"selected_tools": agent.selected_tools,
|
if run is None:
|
||||||
"temperature": agent.temperature,
|
raise HTTPException(status_code=404, detail="Run not found")
|
||||||
"created_at": agent.created_at
|
return run
|
||||||
}
|
|
||||||
for agent in agents
|
|
||||||
]
|
|
||||||
|
|
||||||
@app.get("/agents/{agent_id}")
|
@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs")
|
||||||
async def get_agent(agent_id: int):
|
async def list_runs(
|
||||||
agent = await agent_manager.get_agent(agent_id)
|
thread_id: str = Path(..., description="The ID of the thread the runs belong to"),
|
||||||
if agent:
|
limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned")
|
||||||
return {
|
):
|
||||||
"id": agent.id,
|
"""
|
||||||
"name": agent.name,
|
Returns a list of runs belonging to a thread.
|
||||||
"model": agent.model,
|
"""
|
||||||
"system_prompt": agent.system_prompt,
|
runs = await manager.list_runs(thread_id, limit)
|
||||||
"selected_tools": agent.selected_tools,
|
return runs
|
||||||
"temperature": agent.temperature,
|
|
||||||
"created_at": agent.created_at
|
|
||||||
}
|
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
|
||||||
|
|
||||||
@app.put("/agents/{agent_id}")
|
# Add more endpoints as needed for production use
|
||||||
async def update_agent(agent_id: int, agent: Agent):
|
|
||||||
success = await agent_manager.update_agent(agent_id, **agent.dict())
|
|
||||||
if success:
|
|
||||||
return {"status": "success"}
|
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
|
||||||
|
|
||||||
@app.delete("/agents/{agent_id}")
|
if __name__ == "__main__":
|
||||||
async def delete_agent(agent_id: int):
|
import uvicorn
|
||||||
success = await agent_manager.delete_agent(agent_id)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
if success:
|
|
||||||
return {"status": "success"}
|
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
|
44
core/db.py
44
core/db.py
|
@ -1,17 +1,19 @@
|
||||||
from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON
|
from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON, Boolean
|
||||||
from sqlalchemy.orm import relationship, declarative_base
|
from sqlalchemy.orm import relationship, declarative_base
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from core.config import settings # Changed from Settings to settings
|
from core.config import settings # Changed from Settings to settings
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
class Thread(Base):
|
class Thread(Base):
|
||||||
__tablename__ = 'threads'
|
__tablename__ = 'threads'
|
||||||
|
|
||||||
thread_id = Column(Integer, primary_key=True)
|
thread_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
messages = Column(Text)
|
messages = Column(Text)
|
||||||
creation_date = Column(String)
|
creation_date = Column(String)
|
||||||
last_updated_date = Column(String)
|
last_updated_date = Column(String)
|
||||||
|
@ -22,25 +24,31 @@ class Thread(Base):
|
||||||
class ThreadRun(Base):
|
class ThreadRun(Base):
|
||||||
__tablename__ = 'thread_runs'
|
__tablename__ = 'thread_runs'
|
||||||
|
|
||||||
run_id = Column(Integer, primary_key=True)
|
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
thread_id = Column(Integer, ForeignKey('threads.thread_id'))
|
thread_id = Column(String(36), ForeignKey('threads.thread_id'))
|
||||||
messages = Column(Text)
|
created_at = Column(Integer)
|
||||||
creation_date = Column(String)
|
status = Column(String)
|
||||||
status = Column(String)
|
last_error = Column(Text, nullable=True)
|
||||||
error_message = Column(Text, nullable=True)
|
started_at = Column(Integer, nullable=True)
|
||||||
|
cancelled_at = Column(Integer, nullable=True)
|
||||||
|
failed_at = Column(Integer, nullable=True)
|
||||||
|
completed_at = Column(Integer, nullable=True)
|
||||||
|
model = Column(String)
|
||||||
|
system_message = Column(Text)
|
||||||
|
tools = Column(JSON, nullable=True)
|
||||||
|
usage = Column(JSON, nullable=True)
|
||||||
|
temperature = Column(Float, nullable=True)
|
||||||
|
top_p = Column(Float, nullable=True)
|
||||||
|
max_tokens = Column(Integer, nullable=True)
|
||||||
|
tool_choice = Column(String, nullable=True)
|
||||||
|
execute_tools_async = Column(Boolean)
|
||||||
|
response_format = Column(JSON, nullable=True)
|
||||||
|
|
||||||
thread = relationship("Thread", back_populates="thread_runs")
|
thread = relationship("Thread", back_populates="thread_runs")
|
||||||
|
|
||||||
class Agent(Base):
|
def __init__(self, **kwargs):
|
||||||
__tablename__ = 'agents'
|
super().__init__(**kwargs)
|
||||||
|
self.created_at = int(datetime.utcnow().timestamp())
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
model = Column(String, nullable=False)
|
|
||||||
name = Column(String, nullable=False)
|
|
||||||
system_prompt = Column(Text, nullable=False)
|
|
||||||
selected_tools = Column(JSON) # Changed from ARRAY to JSON
|
|
||||||
temperature = Column(Float, default=0.5)
|
|
||||||
created_at = Column(String, nullable=False)
|
|
||||||
|
|
||||||
# class MemoryModule(Base):
|
# class MemoryModule(Base):
|
||||||
# __tablename__ = 'memory_modules'
|
# __tablename__ = 'memory_modules'
|
||||||
|
|
67
core/llm.py
67
core/llm.py
|
@ -1,4 +1,4 @@
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, Dict, Any
|
||||||
import litellm
|
import litellm
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
@ -26,24 +26,13 @@ os.environ['GROQ_API_KEY'] = GROQ_API_KEY
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False) -> AsyncGenerator[Union[dict, str], None]:
|
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
||||||
async def attempt_api_call(api_call_func, max_attempts=3):
|
async def attempt_api_call(api_call_func, max_attempts=3):
|
||||||
for attempt in range(max_attempts):
|
for attempt in range(max_attempts):
|
||||||
try:
|
try:
|
||||||
response = await api_call_func()
|
return await api_call_func()
|
||||||
if stream:
|
|
||||||
async for chunk in response:
|
|
||||||
yield chunk
|
|
||||||
else:
|
|
||||||
response_content = response.choices[0].message['content'] if json_mode else response
|
|
||||||
if json_mode:
|
|
||||||
if not json.loads(response_content):
|
|
||||||
logger.info(f"Invalid JSON received, retrying attempt {attempt + 1}")
|
|
||||||
continue
|
|
||||||
yield response
|
|
||||||
return
|
|
||||||
except litellm.exceptions.RateLimitError as e:
|
except litellm.exceptions.RateLimitError as e:
|
||||||
logger.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...")
|
logger.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...")
|
||||||
await asyncio.sleep(30)
|
await asyncio.sleep(30)
|
||||||
|
@ -60,7 +49,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"response_format": {"type": "json_object"} if json_mode else None,
|
"response_format": response_format or ({"type": "json_object"} if json_mode else None),
|
||||||
|
"top_p": top_p,
|
||||||
|
"stream": stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add api_key and api_base if provided
|
# Add api_key and api_base if provided
|
||||||
|
@ -88,12 +79,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
|
||||||
api_call_params["tool_choice"] = tool_choice
|
api_call_params["tool_choice"] = tool_choice
|
||||||
|
|
||||||
if "claude" in model_name.lower() or "anthropic" in model_name.lower():
|
if "claude" in model_name.lower() or "anthropic" in model_name.lower():
|
||||||
# if messages[0]["role"] != "user":
|
|
||||||
# api_call_params["messages"] = [{"role": "user", "content": "."}] + messages
|
|
||||||
api_call_params["extra_headers"] = {
|
api_call_params["extra_headers"] = {
|
||||||
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
|
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
|
||||||
}
|
}
|
||||||
# # "anthropic-beta": "prompt-caching-2024-07-31"
|
|
||||||
|
|
||||||
# Log the API request
|
# Log the API request
|
||||||
logger.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")
|
logger.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")
|
||||||
|
@ -101,20 +89,49 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
|
||||||
if agentops_session:
|
if agentops_session:
|
||||||
response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
|
response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
|
||||||
else:
|
else:
|
||||||
if stream:
|
response = await litellm.acompletion(**api_call_params)
|
||||||
response = await litellm.acompletion(**api_call_params, stream=True)
|
|
||||||
else:
|
|
||||||
response = await litellm.acompletion(**api_call_params)
|
|
||||||
|
|
||||||
# Log the API response
|
# Log the API response
|
||||||
logger.info(f"Received API response: {response}")
|
logger.info(f"Received API response: {response}")
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async for result in attempt_api_call(api_call):
|
return await attempt_api_call(api_call)
|
||||||
yield result
|
|
||||||
|
|
||||||
# Sample Usage
|
# Sample Usage
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
async def test_llm_api_call(stream=True):
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Complex essay on economics"}
|
||||||
|
]
|
||||||
|
model_name = "gpt-4o"
|
||||||
|
|
||||||
pass
|
response = await make_llm_api_call(messages, model_name, stream=stream)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
print("Streaming response:")
|
||||||
|
async for chunk in response:
|
||||||
|
if isinstance(chunk, dict) and 'choices' in chunk:
|
||||||
|
content = chunk['choices'][0]['delta'].get('content', '')
|
||||||
|
print(content, end='', flush=True)
|
||||||
|
else:
|
||||||
|
# For non-dict responses (like ModelResponse objects)
|
||||||
|
content = chunk.choices[0].delta.content
|
||||||
|
if content:
|
||||||
|
print(content, end='', flush=True)
|
||||||
|
print("\nStream completed.")
|
||||||
|
else:
|
||||||
|
print("Non-streaming response:")
|
||||||
|
if isinstance(response, dict) and 'choices' in response:
|
||||||
|
print(response['choices'][0]['message']['content'])
|
||||||
|
else:
|
||||||
|
# For non-dict responses (like ModelResponse objects)
|
||||||
|
print(response.choices[0].message.content)
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
# asyncio.run(test_llm_api_call(stream=True)) # For streaming
|
||||||
|
# asyncio.run(test_llm_api_call(stream=False)) # For non-streaming
|
||||||
|
|
||||||
|
asyncio.run(test_llm_api_call())
|
|
@ -5,19 +5,18 @@ from typing import List, Dict, Any, Optional, Callable, AsyncGenerator, Union
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import select, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from core.db import Database, Thread, ThreadRun
|
from core.db import Database, Thread, ThreadRun
|
||||||
from core.tools.tool import Tool, ToolResult
|
from core.tool import Tool, ToolResult
|
||||||
from core.llm import make_llm_api_call
|
from core.llm import make_llm_api_call
|
||||||
# from core.working_memory_manager import WorkingMemory
|
# from core.working_memory_manager import WorkingMemory
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from core.tools.tool_registry import ToolRegistry
|
from core.tool_registry import ToolRegistry
|
||||||
import re
|
import re
|
||||||
from core.agent_manager import AgentManager
|
import uuid
|
||||||
|
|
||||||
class ThreadManager:
|
class ThreadManager:
|
||||||
def __init__(self, db: Database):
|
def __init__(self, db: Database):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.tool_registry = ToolRegistry()
|
self.tool_registry = ToolRegistry()
|
||||||
self.agent_manager = AgentManager(db)
|
|
||||||
|
|
||||||
async def create_thread(self) -> int:
|
async def create_thread(self) -> int:
|
||||||
async with self.db.get_async_session() as session:
|
async with self.db.get_async_session() as session:
|
||||||
|
@ -197,84 +196,176 @@ class ThreadManager:
|
||||||
|
|
||||||
async def run_thread(
|
async def run_thread(
|
||||||
self,
|
self,
|
||||||
thread_id: int,
|
thread_id: str,
|
||||||
agent_id: Optional[int] = None,
|
system_message: Dict[str, Any],
|
||||||
system_message: Optional[Dict[str, Any]] = None,
|
model_name: str,
|
||||||
model_name: Optional[str] = None,
|
temperature: float = 0.5,
|
||||||
temperature: Optional[float] = None,
|
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
tools: Optional[List[str]] = None,
|
||||||
additional_instructions: Optional[str] = None,
|
additional_system_message: Optional[str] = None,
|
||||||
hide_tool_msgs: bool = False,
|
hide_tool_msgs: bool = False,
|
||||||
execute_tools_async: bool = True,
|
execute_tools_async: bool = True,
|
||||||
use_tool_parser: bool = False,
|
use_tool_parser: bool = False,
|
||||||
stream: bool = False
|
top_p: Optional[float] = None,
|
||||||
) -> AsyncGenerator[Union[Dict[str, Any], str], None]:
|
tool_choice: str = "auto",
|
||||||
if agent_id is not None:
|
response_format: Optional[Dict[str, Any]] = None
|
||||||
agent = await self.agent_manager.get_agent(agent_id)
|
) -> Dict[str, Any]:
|
||||||
if not agent:
|
run_id = str(uuid.uuid4())
|
||||||
raise ValueError(f"Agent with id {agent_id} not found")
|
|
||||||
system_message = {"role": "system", "content": agent.system_prompt}
|
# Fetch full tool objects based on the provided tool names
|
||||||
model_name = agent.model
|
full_tools = None
|
||||||
temperature = agent.temperature
|
if tools:
|
||||||
tools = [self.tool_registry.get_tool(tool).schema()[0] for tool in agent.selected_tools] if agent.selected_tools else None
|
full_tools = [self.tool_registry.get_tool(tool_name)['schema'] for tool_name in tools if self.tool_registry.get_tool(tool_name)]
|
||||||
elif system_message is None or model_name is None:
|
|
||||||
raise ValueError("Either agent_id or system_message and model_name must be provided")
|
thread_run = ThreadRun(
|
||||||
|
id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
status="queued",
|
||||||
|
model=model_name,
|
||||||
|
system_message=json.dumps(system_message),
|
||||||
|
tools=json.dumps(full_tools) if full_tools else None,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
execute_tools_async=execute_tools_async,
|
||||||
|
response_format=json.dumps(response_format) if response_format else None
|
||||||
|
)
|
||||||
|
|
||||||
if await self.should_stop(thread_id):
|
async with self.db.get_async_session() as session:
|
||||||
yield {"status": "stopped", "message": "Session cancelled"}
|
session.add(thread_run)
|
||||||
return
|
await session.commit()
|
||||||
|
|
||||||
if use_tool_parser:
|
|
||||||
hide_tool_msgs = True
|
|
||||||
|
|
||||||
await self.cleanup_incomplete_tool_calls(thread_id)
|
|
||||||
|
|
||||||
# Prepare messages
|
|
||||||
messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs)
|
|
||||||
prepared_messages = [system_message] + messages
|
|
||||||
|
|
||||||
if additional_instructions:
|
|
||||||
additional_instruction_message = {
|
|
||||||
"role": "user",
|
|
||||||
"content": additional_instructions
|
|
||||||
}
|
|
||||||
prepared_messages.append(additional_instruction_message)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_stream = make_llm_api_call(
|
thread_run.status = "in_progress"
|
||||||
|
thread_run.started_at = int(datetime.utcnow().timestamp())
|
||||||
|
await self.update_thread_run(thread_run)
|
||||||
|
|
||||||
|
if await self.should_stop(thread_id):
|
||||||
|
thread_run.status = "cancelled"
|
||||||
|
thread_run.cancelled_at = int(datetime.utcnow().timestamp())
|
||||||
|
await self.update_thread_run(thread_run)
|
||||||
|
return {"status": "stopped", "message": "Session cancelled"}
|
||||||
|
|
||||||
|
if use_tool_parser:
|
||||||
|
hide_tool_msgs = True
|
||||||
|
|
||||||
|
await self.cleanup_incomplete_tool_calls(thread_id)
|
||||||
|
|
||||||
|
# Prepare messages
|
||||||
|
messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs)
|
||||||
|
prepared_messages = [system_message] + messages
|
||||||
|
|
||||||
|
if additional_system_message:
|
||||||
|
additional_instruction_message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": additional_system_message
|
||||||
|
}
|
||||||
|
prepared_messages.append(additional_instruction_message)
|
||||||
|
|
||||||
|
response = await make_llm_api_call(
|
||||||
prepared_messages,
|
prepared_messages,
|
||||||
model_name,
|
model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
tools=tools,
|
tools=full_tools,
|
||||||
tool_choice="auto",
|
tool_choice=tool_choice,
|
||||||
stream=stream
|
stream=False,
|
||||||
|
top_p=top_p,
|
||||||
|
response_format=response_format
|
||||||
)
|
)
|
||||||
|
|
||||||
async for partial_response in response_stream:
|
usage = response.usage if hasattr(response, 'usage') else None
|
||||||
if stream:
|
usage_dict = self.serialize_usage(usage) if usage else None
|
||||||
yield partial_response
|
thread_run.usage = usage_dict
|
||||||
else:
|
|
||||||
response = partial_response
|
|
||||||
|
|
||||||
if not stream:
|
# Add the assistant's message to the thread
|
||||||
if tools is None or use_tool_parser:
|
assistant_message = {
|
||||||
await self.handle_response_without_tools(thread_id, response, use_tool_parser)
|
"role": "assistant",
|
||||||
else:
|
"content": response.choices[0].message['content']
|
||||||
await self.handle_response_with_tools(thread_id, response, execute_tools_async)
|
}
|
||||||
|
if 'tool_calls' in response.choices[0].message:
|
||||||
|
assistant_message['tool_calls'] = response.choices[0].message['tool_calls']
|
||||||
|
|
||||||
|
await self.add_message(thread_id, assistant_message)
|
||||||
|
|
||||||
if await self.should_stop(thread_id):
|
if tools is None or use_tool_parser:
|
||||||
yield {"status": "stopped", "message": "Session cancelled"}
|
await self.handle_response_without_tools(thread_id, response, use_tool_parser)
|
||||||
else:
|
else:
|
||||||
await self.save_thread_run(thread_id)
|
await self.handle_response_with_tools(thread_id, response, execute_tools_async)
|
||||||
yield response
|
|
||||||
|
thread_run.status = "completed"
|
||||||
|
thread_run.completed_at = int(datetime.utcnow().timestamp())
|
||||||
|
await self.update_thread_run(thread_run)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": thread_run.id,
|
||||||
|
"choices": [self.serialize_choice(choice) for choice in response.choices],
|
||||||
|
"usage": usage_dict,
|
||||||
|
"model": model_name,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(datetime.utcnow().timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Error in API call: {str(e)}\n\nFull error: {repr(e)}"
|
error_message = f"Error in API call: {str(e)}\n\nFull error: {repr(e)}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
await self.update_thread_run_with_error(thread_id, error_message)
|
thread_run.status = "failed"
|
||||||
yield {"status": "error", "message": error_message}
|
thread_run.failed_at = int(datetime.utcnow().timestamp())
|
||||||
|
thread_run.last_error = error_message
|
||||||
|
await self.update_thread_run(thread_run)
|
||||||
|
return {"status": "error", "message": error_message}
|
||||||
|
|
||||||
|
def serialize_usage(self, usage):
|
||||||
|
return {
|
||||||
|
"completion_tokens": usage.completion_tokens,
|
||||||
|
"prompt_tokens": usage.prompt_tokens,
|
||||||
|
"total_tokens": usage.total_tokens,
|
||||||
|
"completion_tokens_details": self.serialize_completion_tokens_details(usage.completion_tokens_details),
|
||||||
|
"prompt_tokens_details": self.serialize_prompt_tokens_details(usage.prompt_tokens_details)
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_completion_tokens_details(self, details):
|
||||||
|
return {
|
||||||
|
"audio_tokens": details.audio_tokens,
|
||||||
|
"reasoning_tokens": details.reasoning_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_prompt_tokens_details(self, details):
|
||||||
|
return {
|
||||||
|
"audio_tokens": details.audio_tokens,
|
||||||
|
"cached_tokens": details.cached_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_choice(self, choice):
|
||||||
|
return {
|
||||||
|
"finish_reason": choice.finish_reason,
|
||||||
|
"index": choice.index,
|
||||||
|
"message": self.serialize_message(choice.message)
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_message(self, message):
|
||||||
|
return {
|
||||||
|
"content": message.content,
|
||||||
|
"role": message.role,
|
||||||
|
"tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_tool_call(self, tool_call):
|
||||||
|
return {
|
||||||
|
"id": tool_call.id,
|
||||||
|
"type": tool_call.type,
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.function.name,
|
||||||
|
"arguments": tool_call.function.arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def update_thread_run(self, thread_run: ThreadRun):
|
||||||
|
async with self.db.get_async_session() as session:
|
||||||
|
session.add(thread_run)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(thread_run)
|
||||||
|
|
||||||
async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
|
async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
|
||||||
response_content = response.choices[0].message['content']
|
response_content = response.choices[0].message['content']
|
||||||
|
@ -282,8 +373,8 @@ class ThreadManager:
|
||||||
if use_tool_parser:
|
if use_tool_parser:
|
||||||
await self.handle_tool_parser_response(thread_id, response_content)
|
await self.handle_tool_parser_response(thread_id, response_content)
|
||||||
else:
|
else:
|
||||||
logging.info("Adding assistant message to thread.")
|
# The message has already been added in the run_thread method, so we don't need to add it again here
|
||||||
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
|
pass
|
||||||
|
|
||||||
async def handle_tool_parser_response(self, thread_id: int, response_content: str):
|
async def handle_tool_parser_response(self, thread_id: int, response_content: str):
|
||||||
tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content)
|
tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content)
|
||||||
|
@ -325,9 +416,8 @@ class ThreadManager:
|
||||||
response_message = response.choices[0].message
|
response_message = response.choices[0].message
|
||||||
tool_calls = response_message.get('tool_calls', [])
|
tool_calls = response_message.get('tool_calls', [])
|
||||||
|
|
||||||
assistant_message = self.create_assistant_message_with_tools(response_message)
|
# The assistant message has already been added in the run_thread method
|
||||||
await self.add_message(thread_id, assistant_message)
|
|
||||||
|
|
||||||
available_functions = self.get_available_functions()
|
available_functions = self.get_available_functions()
|
||||||
|
|
||||||
if await self.should_stop(thread_id):
|
if await self.should_stop(thread_id):
|
||||||
|
@ -348,9 +438,7 @@ class ThreadManager:
|
||||||
|
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
logging.error(f"AttributeError: {e}")
|
logging.error(f"AttributeError: {e}")
|
||||||
content = response_message.get('content', '')
|
# No need to add the message here as it's already been added in the run_thread method
|
||||||
if content:
|
|
||||||
await self.add_message(thread_id, {"role": "assistant", "content": content})
|
|
||||||
|
|
||||||
def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
|
def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
|
||||||
message = {
|
message = {
|
||||||
|
@ -452,7 +540,7 @@ class ThreadManager:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return result.scalar_one_or_none() is not None
|
return result.scalar_one_or_none() is not None
|
||||||
|
|
||||||
async def save_thread_run(self, thread_id: int):
|
async def save_thread_run(self, thread_id: str):
|
||||||
async with self.db.get_async_session() as session:
|
async with self.db.get_async_session() as session:
|
||||||
thread = await session.get(Thread, thread_id)
|
thread = await session.get(Thread, thread_id)
|
||||||
if not thread:
|
if not thread:
|
||||||
|
@ -461,14 +549,26 @@ class ThreadManager:
|
||||||
messages = json.loads(thread.messages)
|
messages = json.loads(thread.messages)
|
||||||
creation_date = datetime.now().isoformat()
|
creation_date = datetime.now().isoformat()
|
||||||
|
|
||||||
new_thread_run = ThreadRun(
|
# Get the latest ThreadRun for this thread
|
||||||
thread_id=thread_id,
|
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
|
||||||
messages=json.dumps(messages),
|
result = await session.execute(stmt)
|
||||||
creation_date=creation_date,
|
latest_thread_run = result.scalar_one_or_none()
|
||||||
status='completed'
|
|
||||||
)
|
if latest_thread_run:
|
||||||
session.add(new_thread_run)
|
# Update the existing ThreadRun
|
||||||
await session.commit()
|
latest_thread_run.messages = json.dumps(messages)
|
||||||
|
latest_thread_run.last_updated_date = creation_date
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
# Create a new ThreadRun if none exists
|
||||||
|
new_thread_run = ThreadRun(
|
||||||
|
thread_id=thread_id,
|
||||||
|
messages=json.dumps(messages),
|
||||||
|
creation_date=creation_date,
|
||||||
|
status='completed'
|
||||||
|
)
|
||||||
|
session.add(new_thread_run)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def get_thread(self, thread_id: int) -> Optional[Thread]:
|
async def get_thread(self, thread_id: int) -> Optional[Thread]:
|
||||||
async with self.db.get_async_session() as session:
|
async with self.db.get_async_session() as session:
|
||||||
|
@ -489,11 +589,92 @@ class ThreadManager:
|
||||||
result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
|
result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
async def get_latest_thread_run(self, thread_id: int):
|
async def get_latest_thread_run(self, thread_id: str):
|
||||||
async with self.db.get_async_session() as session:
|
async with self.db.get_async_session() as session:
|
||||||
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.run_id.desc()).limit(1)
|
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return result.scalar_one_or_none()
|
latest_run = result.scalar_one_or_none()
|
||||||
|
if latest_run:
|
||||||
|
return {
|
||||||
|
"id": latest_run.id,
|
||||||
|
"status": latest_run.status,
|
||||||
|
"error_message": latest_run.last_error,
|
||||||
|
"created_at": latest_run.created_at,
|
||||||
|
"started_at": latest_run.started_at,
|
||||||
|
"completed_at": latest_run.completed_at,
|
||||||
|
"cancelled_at": latest_run.cancelled_at,
|
||||||
|
"failed_at": latest_run.failed_at,
|
||||||
|
"model": latest_run.model,
|
||||||
|
"usage": latest_run.usage
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
async with self.db.get_async_session() as session:
|
||||||
|
run = await session.get(ThreadRun, run_id)
|
||||||
|
if run and run.thread_id == thread_id:
|
||||||
|
return {
|
||||||
|
"id": run.id,
|
||||||
|
"thread_id": run.thread_id,
|
||||||
|
"status": run.status,
|
||||||
|
"created_at": run.created_at,
|
||||||
|
"started_at": run.started_at,
|
||||||
|
"completed_at": run.completed_at,
|
||||||
|
"cancelled_at": run.cancelled_at,
|
||||||
|
"failed_at": run.failed_at,
|
||||||
|
"model": run.model,
|
||||||
|
"system_message": json.loads(run.system_message) if run.system_message else None,
|
||||||
|
"tools": json.loads(run.tools) if run.tools else None,
|
||||||
|
"usage": run.usage,
|
||||||
|
"temperature": run.temperature,
|
||||||
|
"top_p": run.top_p,
|
||||||
|
"max_tokens": run.max_tokens,
|
||||||
|
"tool_choice": run.tool_choice,
|
||||||
|
"execute_tools_async": run.execute_tools_async,
|
||||||
|
"response_format": json.loads(run.response_format) if run.response_format else None,
|
||||||
|
"last_error": run.last_error
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
async with self.db.get_async_session() as session:
|
||||||
|
run = await session.get(ThreadRun, run_id)
|
||||||
|
if run and run.thread_id == thread_id and run.status == "in_progress":
|
||||||
|
run.status = "cancelled"
|
||||||
|
run.cancelled_at = int(datetime.utcnow().timestamp())
|
||||||
|
await session.commit()
|
||||||
|
return await self.get_run(thread_id, run_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
|
||||||
|
async with self.db.get_async_session() as session:
|
||||||
|
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
runs = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": run.id,
|
||||||
|
"thread_id": run.thread_id,
|
||||||
|
"status": run.status,
|
||||||
|
"created_at": run.created_at,
|
||||||
|
"started_at": run.started_at,
|
||||||
|
"completed_at": run.completed_at,
|
||||||
|
"cancelled_at": run.cancelled_at,
|
||||||
|
"failed_at": run.failed_at,
|
||||||
|
"model": run.model,
|
||||||
|
"system_message": json.loads(run.system_message) if run.system_message else None,
|
||||||
|
"tools": json.loads(run.tools) if run.tools else None,
|
||||||
|
"usage": run.usage,
|
||||||
|
"temperature": run.temperature,
|
||||||
|
"top_p": run.top_p,
|
||||||
|
"max_tokens": run.max_tokens,
|
||||||
|
"tool_choice": run.tool_choice,
|
||||||
|
"execute_tools_async": run.execute_tools_async,
|
||||||
|
"response_format": json.loads(run.response_format) if run.response_format else None,
|
||||||
|
"last_error": run.last_error
|
||||||
|
}
|
||||||
|
for run in runs
|
||||||
|
]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -519,6 +700,4 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print(f"Response: {response}")
|
print(f"Response: {response}")
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Dict, Type, Any
|
from typing import Dict, Type, Any, Optional
|
||||||
from core.tools.tool import Tool
|
from core.tool import Tool
|
||||||
from core.config import settings
|
from core.config import settings
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
|
@ -39,7 +39,7 @@ class ToolRegistry:
|
||||||
|
|
||||||
print(f"Registered tools: {list(self.tools.keys())}") # Debug print
|
print(f"Registered tools: {list(self.tools.keys())}") # Debug print
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Dict[str, Any]:
|
def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
|
||||||
return self.tools.get(tool_name)
|
return self.tools.get(tool_name)
|
||||||
|
|
||||||
def get_all_tools(self) -> Dict[str, Dict[str, Any]]:
|
def get_all_tools(self) -> Dict[str, Dict[str, Any]]:
|
|
@ -1,97 +0,0 @@
|
||||||
import streamlit as st
|
|
||||||
import requests
|
|
||||||
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
|
|
||||||
|
|
||||||
def display_agent_management():
|
|
||||||
st.header("Agent Management")
|
|
||||||
|
|
||||||
col1, col2 = st.columns([1, 1])
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
display_create_agent_form()
|
|
||||||
|
|
||||||
with col2:
|
|
||||||
display_existing_agents()
|
|
||||||
|
|
||||||
def display_create_agent_form():
|
|
||||||
st.subheader("Create New Agent")
|
|
||||||
with st.form("create_agent_form"):
|
|
||||||
new_agent_name = st.text_input("Agent Name")
|
|
||||||
new_agent_model = st.selectbox("Model", AI_MODELS)
|
|
||||||
new_agent_system_prompt = st.text_area("System Prompt", value=STANDARD_SYSTEM_MESSAGE)
|
|
||||||
new_agent_temperature = st.slider("Temperature", 0.0, 1.0, 0.5)
|
|
||||||
tool_options = list(st.session_state.tools.keys())
|
|
||||||
new_agent_tools = st.multiselect("Tools", tool_options)
|
|
||||||
|
|
||||||
submitted = st.form_submit_button("Create Agent")
|
|
||||||
if submitted:
|
|
||||||
create_agent(new_agent_name, new_agent_model, new_agent_system_prompt, new_agent_temperature, new_agent_tools)
|
|
||||||
|
|
||||||
def create_agent(name, model, system_prompt, temperature, tools):
|
|
||||||
response = requests.post(f"{API_BASE_URL}/agents/", json={
|
|
||||||
"name": name,
|
|
||||||
"model": model,
|
|
||||||
"system_prompt": system_prompt,
|
|
||||||
"temperature": temperature,
|
|
||||||
"selected_tools": tools
|
|
||||||
})
|
|
||||||
if response.status_code == 200:
|
|
||||||
st.success("Agent created successfully!")
|
|
||||||
st.session_state.fetch_agents()
|
|
||||||
else:
|
|
||||||
st.error("Failed to create agent.")
|
|
||||||
|
|
||||||
def display_existing_agents():
|
|
||||||
st.subheader("Existing Agents")
|
|
||||||
for agent in st.session_state.agents:
|
|
||||||
with st.expander(f"Agent: {agent['name']}"):
|
|
||||||
display_agent_details(agent)
|
|
||||||
|
|
||||||
def display_agent_details(agent):
|
|
||||||
st.write(f"Model: {agent['model']}")
|
|
||||||
st.write(f"Temperature: {agent['temperature']}")
|
|
||||||
st.write(f"Tools: {', '.join(agent['selected_tools'])}")
|
|
||||||
|
|
||||||
if st.button(f"Edit Agent {agent['id']}"):
|
|
||||||
st.session_state.editing_agent = agent['id']
|
|
||||||
|
|
||||||
if st.button(f"Delete Agent {agent['id']}"):
|
|
||||||
delete_agent(agent['id'])
|
|
||||||
|
|
||||||
if st.session_state.get('editing_agent') == agent['id']:
|
|
||||||
edit_agent_form(agent)
|
|
||||||
|
|
||||||
def edit_agent_form(agent):
|
|
||||||
with st.form(f"edit_agent_form_{agent['id']}"):
|
|
||||||
updated_name = st.text_input("Agent Name", value=agent['name'])
|
|
||||||
updated_model = st.selectbox("Model", AI_MODELS, index=AI_MODELS.index(agent['model']))
|
|
||||||
updated_system_prompt = st.text_area("System Prompt", value=agent['system_prompt'])
|
|
||||||
updated_temperature = st.slider("Temperature", 0.0, 1.0, value=agent['temperature'])
|
|
||||||
tool_options = list(st.session_state.tools.keys())
|
|
||||||
updated_tools = st.multiselect("Tools", options=tool_options, default=agent['selected_tools'])
|
|
||||||
|
|
||||||
if st.form_submit_button("Update Agent"):
|
|
||||||
update_agent(agent['id'], updated_name, updated_model, updated_system_prompt, updated_temperature, updated_tools)
|
|
||||||
st.session_state.editing_agent = None
|
|
||||||
|
|
||||||
def update_agent(agent_id, name, model, system_prompt, temperature, tools):
|
|
||||||
response = requests.put(f"{API_BASE_URL}/agents/{agent_id}", json={
|
|
||||||
"name": name,
|
|
||||||
"model": model,
|
|
||||||
"system_prompt": system_prompt,
|
|
||||||
"temperature": temperature,
|
|
||||||
"selected_tools": tools
|
|
||||||
})
|
|
||||||
if response.status_code == 200:
|
|
||||||
st.success("Agent updated successfully!")
|
|
||||||
st.session_state.fetch_agents()
|
|
||||||
else:
|
|
||||||
st.error("Failed to update agent.")
|
|
||||||
|
|
||||||
def delete_agent(agent_id):
|
|
||||||
response = requests.delete(f"{API_BASE_URL}/agents/{agent_id}")
|
|
||||||
if response.status_code == 200:
|
|
||||||
st.success("Agent deleted successfully!")
|
|
||||||
st.session_state.fetch_agents()
|
|
||||||
else:
|
|
||||||
st.error("Failed to delete agent.")
|
|
|
@ -1,34 +1,45 @@
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from core.ui.thread_management import display_thread_management
|
from core.ui.thread_management import display_thread_management
|
||||||
from core.ui.message_display import display_messages
|
from core.ui.message_display import display_messages_and_runner
|
||||||
from core.ui.thread_runner import display_thread_runner
|
from core.ui.thread_runner import fetch_thread_runs, display_runs
|
||||||
from core.ui.agent_management import display_agent_management
|
|
||||||
from core.ui.tool_display import display_tools
|
from core.ui.tool_display import display_tools
|
||||||
from core.ui.utils import initialize_session_state, fetch_data
|
from core.ui.utils import initialize_session_state, fetch_data, API_BASE_URL
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
initialize_session_state()
|
initialize_session_state()
|
||||||
fetch_data()
|
fetch_data()
|
||||||
|
|
||||||
|
st.set_page_config(page_title="AI Assistant Management System", layout="wide")
|
||||||
|
|
||||||
st.sidebar.title("Navigation")
|
st.sidebar.title("Navigation")
|
||||||
mode = st.sidebar.radio("Select Mode", ["Agent Management", "Thread Management", "Tools"])
|
mode = st.sidebar.radio("Select Mode", ["Thread Management", "Tools"])
|
||||||
|
|
||||||
st.title("AI Assistant Management System")
|
st.title("AI Assistant Management System")
|
||||||
|
|
||||||
if mode == "Agent Management":
|
if mode == "Tools":
|
||||||
display_agent_management()
|
|
||||||
elif mode == "Tools":
|
|
||||||
display_tools()
|
display_tools()
|
||||||
else: # Thread Management
|
else: # Thread Management
|
||||||
display_thread_management_content()
|
display_thread_management_content()
|
||||||
|
|
||||||
def display_thread_management_content():
|
def display_thread_management_content():
|
||||||
st.header("Thread Management")
|
col1, col2 = st.columns([1, 3])
|
||||||
display_thread_management()
|
|
||||||
|
with col1:
|
||||||
|
display_thread_management()
|
||||||
|
if st.session_state.selected_thread:
|
||||||
|
display_thread_runner(st.session_state.selected_thread)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
if st.session_state.selected_thread:
|
||||||
|
display_messages_and_runner(st.session_state.selected_thread)
|
||||||
|
|
||||||
if st.session_state.selected_thread:
|
def display_thread_runner(thread_id):
|
||||||
display_messages(st.session_state.selected_thread)
|
st.subheader("Thread Runs")
|
||||||
display_thread_runner(st.session_state.selected_thread)
|
|
||||||
|
limit = st.number_input("Number of runs to retrieve", min_value=1, max_value=100, value=20)
|
||||||
|
if st.button("Fetch Runs"):
|
||||||
|
runs = fetch_thread_runs(thread_id, limit)
|
||||||
|
display_runs(runs)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
|
@ -1,16 +1,21 @@
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import requests
|
import requests
|
||||||
import json
|
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
|
||||||
from core.ui.utils import API_BASE_URL
|
from core.ui.thread_runner import prepare_run_thread_data, run_thread, display_response_content
|
||||||
|
|
||||||
def display_messages(thread_id):
|
def display_messages_and_runner(thread_id):
|
||||||
st.subheader(f"🧵 Thread ID: {thread_id}")
|
st.subheader(f"Messages for Thread: {thread_id}")
|
||||||
# st.write("### 📝 Messages")
|
|
||||||
|
|
||||||
messages = fetch_messages(thread_id)
|
messages = fetch_messages(thread_id)
|
||||||
display_message_json(messages)
|
|
||||||
display_message_list(messages)
|
display_message_list(messages)
|
||||||
display_add_message_form(thread_id)
|
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
display_add_message_form(thread_id)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
display_run_thread_form(thread_id)
|
||||||
|
|
||||||
def fetch_messages(thread_id):
|
def fetch_messages(thread_id):
|
||||||
messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/")
|
messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/")
|
||||||
|
@ -20,28 +25,37 @@ def fetch_messages(thread_id):
|
||||||
st.error("Failed to fetch messages.")
|
st.error("Failed to fetch messages.")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def display_message_json(messages):
|
|
||||||
if st.button("Show/Hide JSON"):
|
|
||||||
st.session_state.show_json = not st.session_state.get('show_json', False)
|
|
||||||
|
|
||||||
if st.session_state.get('show_json', False):
|
|
||||||
json_str = json.dumps(messages, indent=2)
|
|
||||||
st.code(json_str, language="json")
|
|
||||||
|
|
||||||
def display_message_list(messages):
|
def display_message_list(messages):
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
with st.chat_message(msg['role']):
|
with st.chat_message(msg['role']):
|
||||||
st.write(msg['content'])
|
st.write(msg['content'])
|
||||||
|
|
||||||
def display_add_message_form(thread_id):
|
def display_add_message_form(thread_id):
|
||||||
st.write("### ➕ Add a New Message")
|
st.write("### Add a New Message")
|
||||||
with st.form(key="add_message_form"):
|
with st.form(key="add_message_form"):
|
||||||
role = st.selectbox("🔹 Role", ["user", "assistant"], key="add_role")
|
role = st.selectbox("Role", ["user", "assistant"], key="add_role")
|
||||||
content = st.text_area("📝 Content", key="add_content")
|
content = st.text_area("Content", key="add_content")
|
||||||
submitted = st.form_submit_button("➕ Add Message")
|
submitted = st.form_submit_button("Add Message")
|
||||||
if submitted:
|
if submitted:
|
||||||
add_message(thread_id, role, content)
|
add_message(thread_id, role, content)
|
||||||
|
|
||||||
|
def display_run_thread_form(thread_id):
|
||||||
|
st.write("### Run Thread")
|
||||||
|
with st.form(key="run_thread_form"):
|
||||||
|
model_name = st.selectbox("Model", AI_MODELS, key="model_name")
|
||||||
|
temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
|
||||||
|
max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
|
||||||
|
system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
|
||||||
|
additional_system_message = st.text_area("Additional System Message", key="additional_system_message", height=100)
|
||||||
|
|
||||||
|
tool_options = list(st.session_state.tools.keys())
|
||||||
|
selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
|
||||||
|
|
||||||
|
submitted = st.form_submit_button("Run Thread")
|
||||||
|
if submitted:
|
||||||
|
run_thread_data = prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools)
|
||||||
|
run_thread(thread_id, run_thread_data)
|
||||||
|
|
||||||
def add_message(thread_id, role, content):
|
def add_message(thread_id, role, content):
|
||||||
message_data = {"role": role, "content": content}
|
message_data = {"role": role, "content": content}
|
||||||
add_msg_response = requests.post(
|
add_msg_response = requests.post(
|
||||||
|
@ -49,6 +63,7 @@ def add_message(thread_id, role, content):
|
||||||
json=message_data
|
json=message_data
|
||||||
)
|
)
|
||||||
if add_msg_response.status_code == 200:
|
if add_msg_response.status_code == 200:
|
||||||
|
st.success("Message added successfully.")
|
||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
st.error("Failed to add message.")
|
st.error("Failed to add message.")
|
|
@ -1,22 +1,22 @@
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import requests
|
import requests
|
||||||
from core.ui.utils import API_BASE_URL
|
from core.ui.utils import API_BASE_URL
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
def display_thread_management():
|
def display_thread_management():
|
||||||
col1, col2 = st.columns([1, 2])
|
st.subheader("Thread Management")
|
||||||
|
|
||||||
with col1:
|
if st.button("➕ Create New Thread", key="create_thread_button"):
|
||||||
if st.button("➕ Create New Thread"):
|
create_new_thread()
|
||||||
create_new_thread()
|
|
||||||
|
|
||||||
with col2:
|
display_thread_selector()
|
||||||
display_thread_selector()
|
|
||||||
|
|
||||||
def create_new_thread():
|
def create_new_thread():
|
||||||
response = requests.post(f"{API_BASE_URL}/threads/")
|
response = requests.post(f"{API_BASE_URL}/threads/")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
thread_id = response.json()['thread_id']
|
thread_id = response.json()['thread_id']
|
||||||
st.session_state.selected_thread = thread_id
|
st.session_state.selected_thread = thread_id
|
||||||
|
st.success(f"New thread created with ID: {thread_id}")
|
||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
st.error("Failed to create a new thread.")
|
st.error("Failed to create a new thread.")
|
||||||
|
@ -25,17 +25,36 @@ def display_thread_selector():
|
||||||
threads_response = requests.get(f"{API_BASE_URL}/threads/")
|
threads_response = requests.get(f"{API_BASE_URL}/threads/")
|
||||||
if threads_response.status_code == 200:
|
if threads_response.status_code == 200:
|
||||||
threads = threads_response.json()
|
threads = threads_response.json()
|
||||||
thread_options = [str(thread['thread_id']) for thread in threads]
|
|
||||||
|
# Sort threads by creation date if available, otherwise by thread_id
|
||||||
|
def sort_key(thread):
|
||||||
|
if 'creation_date' in thread:
|
||||||
|
try:
|
||||||
|
return datetime.strptime(thread['creation_date'], "%Y-%m-%d %H:%M:%S")
|
||||||
|
except ValueError:
|
||||||
|
st.warning(f"Invalid date format for thread {thread['thread_id']}")
|
||||||
|
return thread['thread_id']
|
||||||
|
|
||||||
|
sorted_threads = sorted(threads, key=sort_key, reverse=True)
|
||||||
|
|
||||||
|
thread_options = []
|
||||||
|
for thread in sorted_threads:
|
||||||
|
if 'creation_date' in thread:
|
||||||
|
thread_options.append(f"{thread['thread_id']} - Created: {thread['creation_date']}")
|
||||||
|
else:
|
||||||
|
thread_options.append(thread['thread_id'])
|
||||||
|
|
||||||
if st.session_state.selected_thread is None and threads:
|
if st.session_state.selected_thread is None and sorted_threads:
|
||||||
st.session_state.selected_thread = str(threads[0]['thread_id'])
|
st.session_state.selected_thread = sorted_threads[0]['thread_id']
|
||||||
|
|
||||||
selected_thread = st.selectbox(
|
selected_thread = st.selectbox(
|
||||||
"🔍 Select Thread",
|
"🔍 Select Thread",
|
||||||
thread_options,
|
thread_options,
|
||||||
key="thread_select",
|
key="thread_select",
|
||||||
index=thread_options.index(str(st.session_state.selected_thread)) if st.session_state.selected_thread else 0
|
index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
if selected_thread:
|
if selected_thread:
|
||||||
st.session_state.selected_thread = int(selected_thread)
|
st.session_state.selected_thread = selected_thread.split(' - ')[0] if ' - ' in selected_thread else selected_thread
|
||||||
|
else:
|
||||||
|
st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}")
|
|
@ -1,66 +1,16 @@
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import requests
|
import requests
|
||||||
import json
|
from core.ui.utils import API_BASE_URL
|
||||||
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
|
from datetime import datetime
|
||||||
|
|
||||||
def display_thread_runner(thread_id):
|
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools):
|
||||||
st.write("## ⚙️ Run Thread")
|
return {
|
||||||
|
"system_message": {"role": "system", "content": system_message},
|
||||||
manual_config = display_manual_setup_tab()
|
|
||||||
|
|
||||||
# Common settings
|
|
||||||
additional_instructions = st.text_area("Additional Instructions", key="additional_instructions", height=100)
|
|
||||||
stream = st.checkbox("📡 Stream Responses", key="stream_responses")
|
|
||||||
|
|
||||||
# Prepare the run thread data
|
|
||||||
run_thread_data = prepare_run_thread_data(manual_config, additional_instructions, stream)
|
|
||||||
|
|
||||||
# Display the preview of the request payload in an expander
|
|
||||||
with st.expander("📤 Preview Request Payload", expanded=False):
|
|
||||||
st.json(run_thread_data)
|
|
||||||
|
|
||||||
# Center the run button and make it more prominent
|
|
||||||
col1, col2, col3 = st.columns([1, 2, 1])
|
|
||||||
with col2:
|
|
||||||
if st.button("▶️ Run Thread", key="run_thread_button", use_container_width=True):
|
|
||||||
run_thread(thread_id, run_thread_data)
|
|
||||||
|
|
||||||
display_thread_run_status(thread_id)
|
|
||||||
|
|
||||||
def display_manual_setup_tab():
|
|
||||||
model_name = st.selectbox("Model", AI_MODELS, key="model_name")
|
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
|
||||||
with col1:
|
|
||||||
temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
|
|
||||||
with col2:
|
|
||||||
max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
|
|
||||||
|
|
||||||
system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
|
|
||||||
|
|
||||||
tool_options = list(st.session_state.tools.keys())
|
|
||||||
selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
|
|
||||||
|
|
||||||
manual_config = {
|
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"system_message": system_message,
|
"tools": selected_tools,
|
||||||
"selected_tools": selected_tools
|
"additional_system_message": additional_system_message
|
||||||
}
|
|
||||||
|
|
||||||
return manual_config
|
|
||||||
|
|
||||||
def prepare_run_thread_data(manual_config, additional_instructions, stream):
|
|
||||||
tools = [st.session_state.tools[tool]['schema'] for tool in manual_config['selected_tools'] if tool in st.session_state.tools]
|
|
||||||
return {
|
|
||||||
"system_message": {"role": "system", "content": manual_config['system_message']},
|
|
||||||
"model_name": manual_config['model_name'],
|
|
||||||
"temperature": manual_config['temperature'],
|
|
||||||
"max_tokens": manual_config['max_tokens'],
|
|
||||||
"tools": tools,
|
|
||||||
"additional_instructions": additional_instructions,
|
|
||||||
"stream": stream
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_thread(thread_id, run_thread_data):
|
def run_thread(thread_id, run_thread_data):
|
||||||
|
@ -73,49 +23,77 @@ def run_thread(thread_id, run_thread_data):
|
||||||
response_data = run_thread_response.json()
|
response_data = run_thread_response.json()
|
||||||
st.success("Thread run completed successfully!")
|
st.success("Thread run completed successfully!")
|
||||||
|
|
||||||
# Display the return payload in an expander
|
if 'id' in response_data:
|
||||||
with st.expander("📥 Return Payload", expanded=False):
|
st.session_state.latest_run_id = response_data['id']
|
||||||
st.json(response_data)
|
|
||||||
|
|
||||||
# Display the actual response content
|
st.subheader("Response Content")
|
||||||
st.write("### 📬 Response Content")
|
|
||||||
display_response_content(response_data)
|
display_response_content(response_data)
|
||||||
|
|
||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
st.error("Failed to run thread.")
|
st.error(f"Failed to run thread. Status code: {run_thread_response.status_code}")
|
||||||
with st.expander("❌ Error Response", expanded=True):
|
st.text("Response content:")
|
||||||
st.json(run_thread_response.json())
|
st.text(run_thread_response.text)
|
||||||
|
|
||||||
def display_response_content(response_data):
|
def display_response_content(response_data):
|
||||||
if isinstance(response_data, dict) and 'response' in response_data:
|
if isinstance(response_data, dict) and 'choices' in response_data:
|
||||||
for item in response_data['response']:
|
message = response_data['choices'][0]['message']
|
||||||
if isinstance(item, dict):
|
st.write(f"**Role:** {message['role']}")
|
||||||
if 'content' in item:
|
st.write(f"**Content:** {message['content']}")
|
||||||
st.markdown(item['content'])
|
|
||||||
elif 'tool_calls' in item:
|
if 'tool_calls' in message:
|
||||||
st.write("**Tool Calls:**")
|
st.write("**Tool Calls:**")
|
||||||
for tool_call in item['tool_calls']:
|
for tool_call in message['tool_calls']:
|
||||||
st.write(f"- Function: `{tool_call['function']['name']}`")
|
st.write(f"- Function: `{tool_call['function']['name']}`")
|
||||||
st.code(tool_call['function']['arguments'], language="json")
|
st.code(tool_call['function']['arguments'], language="json")
|
||||||
elif isinstance(item, str):
|
|
||||||
st.markdown(item)
|
|
||||||
else:
|
else:
|
||||||
st.json(response_data)
|
st.json(response_data)
|
||||||
|
|
||||||
def display_thread_run_status(thread_id):
|
def fetch_thread_runs(thread_id, limit):
|
||||||
status_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/run/status/")
|
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs?limit={limit}")
|
||||||
if status_response.status_code == 200:
|
if response.status_code == 200:
|
||||||
status_data = status_response.json()
|
return response.json()
|
||||||
st.write("### ⚙️ Thread Run Status")
|
|
||||||
status = status_data.get('status')
|
|
||||||
if status == 'completed':
|
|
||||||
st.success(f"**Status:** {status}")
|
|
||||||
elif status == 'error':
|
|
||||||
st.error(f"**Status:** {status}")
|
|
||||||
with st.expander("Error Details", expanded=True):
|
|
||||||
st.code(status_data.get('error_message'), language="")
|
|
||||||
else:
|
|
||||||
st.info(f"**Status:** {status}")
|
|
||||||
else:
|
else:
|
||||||
st.warning("Could not retrieve thread run status.")
|
st.error("Failed to retrieve runs.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def format_timestamp(timestamp):
|
||||||
|
if timestamp:
|
||||||
|
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
return 'N/A'
|
||||||
|
|
||||||
|
def display_runs(runs):
|
||||||
|
for run in runs:
|
||||||
|
with st.expander(f"Run {run['id']} - Status: {run['status']}", expanded=False):
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
st.write(f"**Created At:** {format_timestamp(run['created_at'])}")
|
||||||
|
st.write(f"**Started At:** {format_timestamp(run['started_at'])}")
|
||||||
|
st.write(f"**Completed At:** {format_timestamp(run['completed_at'])}")
|
||||||
|
st.write(f"**Cancelled At:** {format_timestamp(run['cancelled_at'])}")
|
||||||
|
st.write(f"**Failed At:** {format_timestamp(run['failed_at'])}")
|
||||||
|
with col2:
|
||||||
|
st.write(f"**Model:** {run['model']}")
|
||||||
|
st.write(f"**Temperature:** {run['temperature']}")
|
||||||
|
st.write(f"**Top P:** {run['top_p']}")
|
||||||
|
st.write(f"**Max Tokens:** {run['max_tokens']}")
|
||||||
|
st.write(f"**Tool Choice:** {run['tool_choice']}")
|
||||||
|
st.write(f"**Execute Tools Async:** {run['execute_tools_async']}")
|
||||||
|
|
||||||
|
st.write("**System Message:**")
|
||||||
|
st.json(run['system_message'])
|
||||||
|
|
||||||
|
if run['tools']:
|
||||||
|
st.write("**Tools:**")
|
||||||
|
st.json(run['tools'])
|
||||||
|
|
||||||
|
if run['usage']:
|
||||||
|
st.write("**Usage:**")
|
||||||
|
st.json(run['usage'])
|
||||||
|
|
||||||
|
if run['response_format']:
|
||||||
|
st.write("**Response Format:**")
|
||||||
|
st.json(run['response_format'])
|
||||||
|
|
||||||
|
if run['last_error']:
|
||||||
|
st.error("**Last Error:**")
|
||||||
|
st.code(run['last_error'])
|
|
@ -4,13 +4,6 @@ from core.constants import AI_MODELS, STANDARD_SYSTEM_MESSAGE
|
||||||
|
|
||||||
API_BASE_URL = "http://localhost:8000"
|
API_BASE_URL = "http://localhost:8000"
|
||||||
|
|
||||||
def fetch_agents():
|
|
||||||
response = requests.get(f"{API_BASE_URL}/agents/")
|
|
||||||
if response.status_code == 200:
|
|
||||||
st.session_state.agents = response.json()
|
|
||||||
else:
|
|
||||||
st.error("Failed to fetch agents.")
|
|
||||||
|
|
||||||
def fetch_tools():
|
def fetch_tools():
|
||||||
response = requests.get(f"{API_BASE_URL}/tools/")
|
response = requests.get(f"{API_BASE_URL}/tools/")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
@ -21,15 +14,10 @@ def fetch_tools():
|
||||||
def initialize_session_state():
|
def initialize_session_state():
|
||||||
if 'selected_thread' not in st.session_state:
|
if 'selected_thread' not in st.session_state:
|
||||||
st.session_state.selected_thread = None
|
st.session_state.selected_thread = None
|
||||||
if 'agents' not in st.session_state:
|
|
||||||
st.session_state.agents = []
|
|
||||||
if 'tools' not in st.session_state:
|
if 'tools' not in st.session_state:
|
||||||
st.session_state.tools = []
|
st.session_state.tools = []
|
||||||
if 'fetch_agents' not in st.session_state:
|
|
||||||
st.session_state.fetch_agents = fetch_agents
|
|
||||||
if 'fetch_tools' not in st.session_state:
|
if 'fetch_tools' not in st.session_state:
|
||||||
st.session_state.fetch_tools = fetch_tools
|
st.session_state.fetch_tools = fetch_tools
|
||||||
|
|
||||||
def fetch_data():
|
def fetch_data():
|
||||||
fetch_agents()
|
|
||||||
fetch_tools()
|
fetch_tools()
|
|
@ -417,6 +417,26 @@ files = [
|
||||||
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
|
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fastapi"
|
||||||
|
version = "0.115.0"
|
||||||
|
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "fastapi-0.115.0-py3-none-any.whl", hash = "sha256:17ea427674467486e997206a5ab25760f6b09e069f099b96f5b55a32fb6f1631"},
|
||||||
|
{file = "fastapi-0.115.0.tar.gz", hash = "sha256:f93b4ca3529a8ebc6fc3fcf710e5efa8de3df9b41570958abf1d97d843138004"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
|
||||||
|
starlette = ">=0.37.2,<0.39.0"
|
||||||
|
typing-extensions = ">=4.8.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
|
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "filelock"
|
name = "filelock"
|
||||||
version = "3.16.1"
|
version = "3.16.1"
|
||||||
|
@ -2284,6 +2304,53 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
||||||
pymysql = ["pymysql"]
|
pymysql = ["pymysql"]
|
||||||
sqlcipher = ["sqlcipher3_binary"]
|
sqlcipher = ["sqlcipher3_binary"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sse-starlette"
|
||||||
|
version = "2.1.3"
|
||||||
|
description = "SSE plugin for Starlette"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "sse_starlette-2.1.3-py3-none-any.whl", hash = "sha256:8ec846438b4665b9e8c560fcdea6bc8081a3abf7942faa95e5a744999d219772"},
|
||||||
|
{file = "sse_starlette-2.1.3.tar.gz", hash = "sha256:9cd27eb35319e1414e3d2558ee7414487f9529ce3b3cf9b21434fd110e017169"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = "*"
|
||||||
|
starlette = "*"
|
||||||
|
uvicorn = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
examples = ["fastapi"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sseclient-py"
|
||||||
|
version = "1.7.2"
|
||||||
|
description = "SSE client for Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "sseclient-py-1.7.2.tar.gz", hash = "sha256:ba3197d314766eccb72a1dda80b5fa14a0fbba07d796a287654c07edde88fe0f"},
|
||||||
|
{file = "sseclient_py-1.7.2-py2.py3-none-any.whl", hash = "sha256:a758653b13b78df42cdb696740635a26cb72ad433b75efb68dbbb163d099b6a9"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "starlette"
|
||||||
|
version = "0.38.6"
|
||||||
|
description = "The little ASGI library that shines."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "starlette-0.38.6-py3-none-any.whl", hash = "sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05"},
|
||||||
|
{file = "starlette-0.38.6.tar.gz", hash = "sha256:863a1588f5574e70a821dadefb41e4881ea451a47a3cd1b4df359d4ffefe5ead"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = ">=3.4.0,<5"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "streamlit"
|
name = "streamlit"
|
||||||
version = "1.39.0"
|
version = "1.39.0"
|
||||||
|
@ -2615,6 +2682,24 @@ h2 = ["h2 (>=4,<5)"]
|
||||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||||
zstd = ["zstandard (>=0.18.0)"]
|
zstd = ["zstandard (>=0.18.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "uvicorn"
|
||||||
|
version = "0.31.0"
|
||||||
|
description = "The lightning-fast ASGI server."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "uvicorn-0.31.0-py3-none-any.whl", hash = "sha256:cac7be4dd4d891c363cd942160a7b02e69150dcbc7a36be04d5f4af4b17c8ced"},
|
||||||
|
{file = "uvicorn-0.31.0.tar.gz", hash = "sha256:13bc21373d103859f68fe739608e2eb054a816dea79189bc3ca08ea89a275906"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
click = ">=7.0"
|
||||||
|
h11 = ">=0.8"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "watchdog"
|
name = "watchdog"
|
||||||
version = "5.0.3"
|
version = "5.0.3"
|
||||||
|
@ -2784,4 +2869,4 @@ type = ["pytest-mypy"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.12"
|
python-versions = "^3.12"
|
||||||
content-hash = "817f3d510c67d6a59033e26f11a247490013e3e63c6184844d6b261f43352cfd"
|
content-hash = "11cd10192810fd94515b1f6ad5d8ce27b5d665d1d4e65fff9b05b4cd80c9d377"
|
||||||
|
|
|
@ -21,6 +21,9 @@ litellm = "^1.44.4"
|
||||||
pytest = "^8.3.2"
|
pytest = "^8.3.2"
|
||||||
pytest-asyncio = "^0.24.0"
|
pytest-asyncio = "^0.24.0"
|
||||||
agentops = "^0.3.10"
|
agentops = "^0.3.10"
|
||||||
|
sseclient-py = "1.7.2"
|
||||||
|
fastapi = "^0.115.0"
|
||||||
|
sse-starlette = "^2.1.3"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from core.tools.tool import Tool, ToolResult
|
from core.tool import Tool, ToolResult
|
||||||
from core.config import settings
|
from core.config import settings
|
||||||
|
|
||||||
class FilesTool(Tool):
|
class FilesTool(Tool):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from core.tools.tool import Tool, ToolResult
|
from core.tool import Tool, ToolResult
|
||||||
|
|
||||||
class ExampleTool(Tool):
|
class ExampleTool(Tool):
|
||||||
description = "An example tool for demonstration purposes."
|
description = "An example tool for demonstration purposes."
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Hello there! This is a sample file created just for you.
|
|
@ -0,0 +1 @@
|
||||||
|
This file contains random information about this assistant. I'm here to help you with a wide range of tasks, from answering questions to managing files.
|
|
@ -1 +0,0 @@
|
||||||
This is some random content for the file.
|
|
|
@ -1 +0,0 @@
|
||||||
random contents
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
Here are some random tips:
|
||||||
|
1. Stay curious and keep learning.
|
||||||
|
2. Organize your workspace for better productivity.
|
||||||
|
3. Take breaks to maintain focus and energy.
|
Loading…
Reference in New Issue