mirror of https://github.com/kortix-ai/suna.git

wip

This commit is contained in:
parent a933193851
commit 6c903fa761

@@ -1,58 +0,0 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from core.db import Agent
from datetime import datetime
from typing import List, Optional
import json


class AgentManager:
    def __init__(self, db):
        self.db = db

    async def create_agent(self, model: str, name: str, system_prompt: str, selected_tools: List[str], temperature: float = 0.5) -> int:
        async with self.db.get_async_session() as session:
            new_agent = Agent(
                model=model,
                name=name,
                system_prompt=system_prompt,
                selected_tools=selected_tools, # Store as a list directly
                temperature=temperature,
                created_at=datetime.now().isoformat()
            )
            session.add(new_agent)
            await session.commit()
            await session.refresh(new_agent)
            return new_agent.id

    async def get_agent(self, agent_id: int) -> Optional[Agent]:
        async with self.db.get_async_session() as session:
            result = await session.execute(select(Agent).filter(Agent.id == agent_id))
            agent = result.scalar_one_or_none()
            return agent

    async def update_agent(self, agent_id: int, **kwargs) -> bool:
        async with self.db.get_async_session() as session:
            result = await session.execute(select(Agent).filter(Agent.id == agent_id))
            agent = result.scalar_one_or_none()
            if agent:
                for key, value in kwargs.items():
                    setattr(agent, key, value)
                await session.commit()
                return True
            return False

    async def delete_agent(self, agent_id: int) -> bool:
        async with self.db.get_async_session() as session:
            result = await session.execute(select(Agent).filter(Agent.id == agent_id))
            agent = result.scalar_one_or_none()
            if agent:
                await session.delete(agent)
                await session.commit()
                return True
            return False

    async def list_agents(self) -> List[Agent]:
        async with self.db.get_async_session() as session:
            result = await session.execute(select(Agent))
            agents = result.scalars().all()
            return agents

core/api.py (231 changed lines)

@@ -1,109 +1,107 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Query, Path
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
import asyncio
from fastapi.responses import StreamingResponse

from core.db import Database
from core.thread_manager import ThreadManager
from core.agent_manager import AgentManager
from core.tools.tool_registry import ToolRegistry
from core.tool_registry import ToolRegistry
from core.config import Settings

app = FastAPI()
app = FastAPI(
    title="Thread Manager API",
    description="API for managing and running threads with LLM integration",
    version="1.0.0",
)

db = Database()
manager = ThreadManager(db)
tool_registry = ToolRegistry() # Initialize here
agent_manager = AgentManager(db)
tool_registry = ToolRegistry()

# Pydantic models for request and response bodies
class Message(BaseModel):
    role: str
    content: str
    role: str = Field(..., description="The role of the message sender (e.g., 'user', 'assistant')")
    content: str = Field(..., description="The content of the message")

class RunThreadRequest(BaseModel):
    system_message: Dict[str, Any]
    model_name: str
    temperature: float = 0.5
    max_tokens: Optional[int] = 500
    tools: Optional[List[str]] = None
    tool_choice: str = "required"
    additional_instructions: Optional[str] = None
    stream: bool = False
    system_message: Dict[str, Any] = Field(..., description="The system message to be used for the thread run")
    model_name: str = Field(..., description="The name of the LLM model to be used")
    temperature: float = Field(0.5, description="The sampling temperature for the LLM")
    max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate")
    tools: Optional[List[str]] = Field(None, description="The list of tools to be used in the thread run")
    tool_choice: str = Field("auto", description="Controls which tool is called by the model")
    additional_system_message: Optional[str] = Field(None, description="Additional system message to be appended")
    stream: bool = Field(False, description="Whether to stream the response")
    top_p: Optional[float] = Field(None, description="The nucleus sampling value")
    response_format: Optional[Dict[str, Any]] = Field(None, description="Specifies the format that the model must output")

class Agent(BaseModel):
    name: str
    model: str
    system_prompt: str
    selected_tools: List[str]
    temperature: float = 0.5

@app.post("/threads/")
@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread")
async def create_thread():
    """
    Create a new thread and return its ID.
    """
    thread_id = await manager.create_thread()
    return {"thread_id": thread_id}

@app.get("/threads/")
@app.get("/threads/", response_model=List[Dict[str, str]], summary="Get all threads")
async def get_threads():
    """
    Retrieve a list of all thread IDs.
    """
    threads = await manager.get_threads()
    return [{"thread_id": thread.thread_id} for thread in threads]

@app.post("/threads/{thread_id}/messages/")
async def add_message(thread_id: int, message: Message):
@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread")
async def add_message(thread_id: str, message: Message):
    """
    Add a new message to the specified thread.
    """
    await manager.add_message(thread_id, message.dict())
    return {"status": "success"}

@app.get("/threads/{thread_id}/messages/")
async def list_messages(thread_id: int):
@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread")
async def list_messages(thread_id: str):
    """
    Retrieve all messages from the specified thread.
    """
    messages = await manager.list_messages(thread_id)
    return messages

@app.post("/threads/{thread_id}/run/")
async def run_thread(thread_id: int, request: Dict[str, Any]):
    if 'agent_id' in request:
        # Agent-based run
        response_gen = manager.run_thread(
            thread_id=thread_id,
            agent_id=request['agent_id'],
            additional_instructions=request.get('additional_instructions'),
            stream=request.get('stream', False)
        )
@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread")
async def run_thread(thread_id: str, request: RunThreadRequest):
    """
    Run the specified thread with the given parameters.
    """
    response = await manager.run_thread(
        thread_id=thread_id,
        system_message=request.system_message,
        model_name=request.model_name,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
        tools=request.tools,
        additional_system_message=request.additional_system_message,
        top_p=request.top_p,
        tool_choice=request.tool_choice,
        response_format=request.response_format)

    return {"status": "success", "response": response}

@app.get("/threads/{thread_id}/run/status/", response_model=Dict[str, Any], summary="Get thread run status")
async def get_thread_run_status(thread_id: str):
    """
    Retrieve the status of the latest run for the specified thread.
    """
    latest_thread_run = await manager.get_latest_thread_run(thread_id)
    if latest_thread_run:
        return latest_thread_run
    else:
        # Manual configuration run
        response_gen = manager.run_thread(
            thread_id=thread_id,
            system_message=request['system_message'],
            model_name=request['model_name'],
            temperature=request['temperature'],
            max_tokens=request['max_tokens'],
            tools=request['tools'],
            additional_instructions=request.get('additional_instructions'),
            stream=request.get('stream', False)
        )
        return {"status": "No runs found for this thread."}

    if request.get('stream', False):
        raise HTTPException(status_code=501, detail="Streaming is not supported via this endpoint.")
    else:
        response = []
        async for chunk in response_gen:
            response.append(chunk)
        return {"response": response}

@app.get("/threads/{thread_id}/run/status/")
async def get_thread_run_status(thread_id: int):
    try:
        latest_thread_run = await manager.get_latest_thread_run(thread_id)
        if latest_thread_run:
            return {
                "status": latest_thread_run.status,
                "error_message": latest_thread_run.error_message
            }
        else:
            return {"status": "No runs found for this thread."}
    except AttributeError:
        return {"status": "Error", "message": "Unable to retrieve thread run status."}

@app.get("/tools/")
@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools")
async def get_tools():
    """
    Retrieve a list of all available tools and their schemas.
    """
    tools = tool_registry.get_all_tools()
    if not tools:
        print("No tools found in the registry") # Debug print

@@ -116,52 +114,45 @@ async def get_tools():
        for name, tool_info in tools.items()
    }

@app.post("/agents/")
async def create_agent(agent: Agent):
    agent_id = await agent_manager.create_agent(**agent.dict())
    return {"agent_id": agent_id}
@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run")
async def get_run(
    thread_id: str = Path(..., description="The ID of the thread that was run"),
    run_id: str = Path(..., description="The ID of the run to retrieve")
):
    """
    Retrieve the run object matching the specified ID.
    """
    run = await manager.get_run(thread_id, run_id)
    if run is None:
        raise HTTPException(status_code=404, detail="Run not found")
    return run

@app.get("/agents/")
async def list_agents():
    agents = await agent_manager.list_agents()
    return [
        {
            "id": agent.id,
            "name": agent.name,
            "model": agent.model,
            "system_prompt": agent.system_prompt,
            "selected_tools": agent.selected_tools,
            "temperature": agent.temperature,
            "created_at": agent.created_at
        }
        for agent in agents
    ]
@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run")
async def cancel_run(
    thread_id: str = Path(..., description="The ID of the thread to which this run belongs"),
    run_id: str = Path(..., description="The ID of the run to cancel")
):
    """
    Cancels a run that is in_progress.
    """
    run = await manager.cancel_run(thread_id, run_id)
    if run is None:
        raise HTTPException(status_code=404, detail="Run not found")
    return run

@app.get("/agents/{agent_id}")
async def get_agent(agent_id: int):
    agent = await agent_manager.get_agent(agent_id)
    if agent:
        return {
            "id": agent.id,
            "name": agent.name,
            "model": agent.model,
            "system_prompt": agent.system_prompt,
            "selected_tools": agent.selected_tools,
            "temperature": agent.temperature,
            "created_at": agent.created_at
        }
    raise HTTPException(status_code=404, detail="Agent not found")
@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs")
async def list_runs(
    thread_id: str = Path(..., description="The ID of the thread the runs belong to"),
    limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned")
):
    """
    Returns a list of runs belonging to a thread.
    """
    runs = await manager.list_runs(thread_id, limit)
    return runs

@app.put("/agents/{agent_id}")
async def update_agent(agent_id: int, agent: Agent):
    success = await agent_manager.update_agent(agent_id, **agent.dict())
    if success:
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Agent not found")
# Add more endpoints as needed for production use

@app.delete("/agents/{agent_id}")
async def delete_agent(agent_id: int):
    success = await agent_manager.delete_agent(agent_id)
    if success:
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Agent not found")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

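Aside (not part of the commit): a minimal client sketch of the revised run endpoint. It assumes the API is served locally on port 8000 and that a tool named "web_search" is registered; both names are illustrative.

# --- client sketch: create a thread, run it, then fetch the run record ---
import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

thread_id = requests.post(f"{BASE_URL}/threads/").json()["thread_id"]
requests.post(f"{BASE_URL}/threads/{thread_id}/messages/",
              json={"role": "user", "content": "Summarize the latest changes."})

run = requests.post(f"{BASE_URL}/threads/{thread_id}/run/", json={
    "system_message": {"role": "system", "content": "You are a helpful assistant."},
    "model_name": "gpt-4o",
    "temperature": 0.5,
    "max_tokens": 500,
    "tools": ["web_search"],  # hypothetical registered tool name
    "tool_choice": "auto",
}).json()

# The run result carries an id, so the run can be inspected or cancelled later.
run_id = run["response"]["id"]
print(requests.get(f"{BASE_URL}/threads/{thread_id}/runs/{run_id}").json()["status"])
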
core/db.py (44 changed lines)

@@ -1,17 +1,19 @@
from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON
from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON, Boolean
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from core.config import settings # Changed from Settings to settings
import os
from contextlib import asynccontextmanager
import uuid
from datetime import datetime

Base = declarative_base()

class Thread(Base):
    __tablename__ = 'threads'

    thread_id = Column(Integer, primary_key=True)
    thread_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    messages = Column(Text)
    creation_date = Column(String)
    last_updated_date = Column(String)

@@ -22,25 +24,31 @@ class Thread(Base):
class ThreadRun(Base):
    __tablename__ = 'thread_runs'

    run_id = Column(Integer, primary_key=True)
    thread_id = Column(Integer, ForeignKey('threads.thread_id'))
    messages = Column(Text)
    creation_date = Column(String)
    status = Column(String)
    error_message = Column(Text, nullable=True)
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    thread_id = Column(String(36), ForeignKey('threads.thread_id'))
    created_at = Column(Integer)
    status = Column(String)
    last_error = Column(Text, nullable=True)
    started_at = Column(Integer, nullable=True)
    cancelled_at = Column(Integer, nullable=True)
    failed_at = Column(Integer, nullable=True)
    completed_at = Column(Integer, nullable=True)
    model = Column(String)
    system_message = Column(Text)
    tools = Column(JSON, nullable=True)
    usage = Column(JSON, nullable=True)
    temperature = Column(Float, nullable=True)
    top_p = Column(Float, nullable=True)
    max_tokens = Column(Integer, nullable=True)
    tool_choice = Column(String, nullable=True)
    execute_tools_async = Column(Boolean)
    response_format = Column(JSON, nullable=True)

    thread = relationship("Thread", back_populates="thread_runs")

class Agent(Base):
    __tablename__ = 'agents'

    id = Column(Integer, primary_key=True)
    model = Column(String, nullable=False)
    name = Column(String, nullable=False)
    system_prompt = Column(Text, nullable=False)
    selected_tools = Column(JSON) # Changed from ARRAY to JSON
    temperature = Column(Float, default=0.5)
    created_at = Column(String, nullable=False)
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.created_at = int(datetime.utcnow().timestamp())

# class MemoryModule(Base):
#     __tablename__ = 'memory_modules'

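Aside (not part of the commit): a sketch of materializing the revised schema, assuming core.db exposes Base as above and that an async driver such as aiosqlite is installed; the database URL is illustrative, not the project's actual config.

# --- sketch: create the threads / thread_runs / agents tables ---
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from core.db import Base

async def init_schema():
    engine = create_async_engine("sqlite+aiosqlite:///./threads.db")  # assumed URL
    async with engine.begin() as conn:
        # Emits CREATE TABLE for every model registered on Base's metadata
        await conn.run_sync(Base.metadata.create_all)

asyncio.run(init_schema())
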
core/llm.py (67 changed lines)

@@ -1,4 +1,4 @@
from typing import Union, AsyncGenerator
from typing import Union, Dict, Any
import litellm
import os
import json

@@ -26,24 +26,13 @@ os.environ['GROQ_API_KEY'] = GROQ_API_KEY
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False) -> AsyncGenerator[Union[dict, str], None]:
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]:
    litellm.set_verbose = True

    async def attempt_api_call(api_call_func, max_attempts=3):
        for attempt in range(max_attempts):
            try:
                response = await api_call_func()
                if stream:
                    async for chunk in response:
                        yield chunk
                else:
                    response_content = response.choices[0].message['content'] if json_mode else response
                    if json_mode:
                        if not json.loads(response_content):
                            logger.info(f"Invalid JSON received, retrying attempt {attempt + 1}")
                            continue
                    yield response
                    return
                return await api_call_func()
            except litellm.exceptions.RateLimitError as e:
                logger.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...")
                await asyncio.sleep(30)

@@ -60,7 +49,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": {"type": "json_object"} if json_mode else None,
        "response_format": response_format or ({"type": "json_object"} if json_mode else None),
        "top_p": top_p,
        "stream": stream,
    }

    # Add api_key and api_base if provided

@@ -88,12 +79,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
        api_call_params["tool_choice"] = tool_choice

    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        # if messages[0]["role"] != "user":
        #     api_call_params["messages"] = [{"role": "user", "content": "."}] + messages
        api_call_params["extra_headers"] = {
            "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
        }
        # # "anthropic-beta": "prompt-caching-2024-07-31"

    # Log the API request
    logger.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")

@@ -101,20 +89,49 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
        if agentops_session:
            response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
        else:
            if stream:
                response = await litellm.acompletion(**api_call_params, stream=True)
            else:
                response = await litellm.acompletion(**api_call_params)
            response = await litellm.acompletion(**api_call_params)

        # Log the API response
        logger.info(f"Received API response: {response}")

        return response

    async for result in attempt_api_call(api_call):
        yield result
    return await attempt_api_call(api_call)

# Sample Usage
if __name__ == "__main__":
    import asyncio
    async def test_llm_api_call(stream=True):
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Complex essay on economics"}
        ]
        model_name = "gpt-4o"

        pass
        response = await make_llm_api_call(messages, model_name, stream=stream)

        if stream:
            print("Streaming response:")
            async for chunk in response:
                if isinstance(chunk, dict) and 'choices' in chunk:
                    content = chunk['choices'][0]['delta'].get('content', '')
                    print(content, end='', flush=True)
                else:
                    # For non-dict responses (like ModelResponse objects)
                    content = chunk.choices[0].delta.content
                    if content:
                        print(content, end='', flush=True)
            print("\nStream completed.")
        else:
            print("Non-streaming response:")
            if isinstance(response, dict) and 'choices' in response:
                print(response['choices'][0]['message']['content'])
            else:
                # For non-dict responses (like ModelResponse objects)
                print(response.choices[0].message.content)

    # Example usage:
    # asyncio.run(test_llm_api_call(stream=True))  # For streaming
    # asyncio.run(test_llm_api_call(stream=False))  # For non-streaming

    asyncio.run(test_llm_api_call())

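Aside (not part of the commit): a sketch exercising the two parameters this hunk adds, top_p and response_format, which are now forwarded straight into litellm.acompletion. It assumes valid provider API keys are set in the environment; the model name is illustrative.

# --- sketch: calling make_llm_api_call with the new parameters ---
import asyncio
from core.llm import make_llm_api_call

async def demo():
    response = await make_llm_api_call(
        [{"role": "user", "content": "Return a JSON object with a 'joke' key."}],
        "gpt-4o",
        top_p=0.9,
        response_format={"type": "json_object"},  # overrides the json_mode default
    )
    print(response.choices[0].message.content)

asyncio.run(demo())
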
@@ -5,19 +5,18 @@ from typing import List, Dict, Any, Optional, Callable, AsyncGenerator, Union
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import Database, Thread, ThreadRun
from core.tools.tool import Tool, ToolResult
from core.tool import Tool, ToolResult
from core.llm import make_llm_api_call
# from core.working_memory_manager import WorkingMemory
from datetime import datetime
from core.tools.tool_registry import ToolRegistry
from core.tool_registry import ToolRegistry
import re
from core.agent_manager import AgentManager
import uuid

class ThreadManager:
    def __init__(self, db: Database):
        self.db = db
        self.tool_registry = ToolRegistry()
        self.agent_manager = AgentManager(db)

    async def create_thread(self) -> int:
        async with self.db.get_async_session() as session:

@@ -197,84 +196,176 @@ class ThreadManager:

    async def run_thread(
        self,
        thread_id: int,
        agent_id: Optional[int] = None,
        system_message: Optional[Dict[str, Any]] = None,
        model_name: Optional[str] = None,
        temperature: Optional[float] = None,
        thread_id: str,
        system_message: Dict[str, Any],
        model_name: str,
        temperature: float = 0.5,
        max_tokens: Optional[int] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        additional_instructions: Optional[str] = None,
        tools: Optional[List[str]] = None,
        additional_system_message: Optional[str] = None,
        hide_tool_msgs: bool = False,
        execute_tools_async: bool = True,
        use_tool_parser: bool = False,
        stream: bool = False
    ) -> AsyncGenerator[Union[Dict[str, Any], str], None]:
        if agent_id is not None:
            agent = await self.agent_manager.get_agent(agent_id)
            if not agent:
                raise ValueError(f"Agent with id {agent_id} not found")
            system_message = {"role": "system", "content": agent.system_prompt}
            model_name = agent.model
            temperature = agent.temperature
            tools = [self.tool_registry.get_tool(tool).schema()[0] for tool in agent.selected_tools] if agent.selected_tools else None
        elif system_message is None or model_name is None:
            raise ValueError("Either agent_id or system_message and model_name must be provided")
        top_p: Optional[float] = None,
        tool_choice: str = "auto",
        response_format: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        run_id = str(uuid.uuid4())

        # Fetch full tool objects based on the provided tool names
        full_tools = None
        if tools:
            full_tools = [self.tool_registry.get_tool(tool_name)['schema'] for tool_name in tools if self.tool_registry.get_tool(tool_name)]

        thread_run = ThreadRun(
            id=run_id,
            thread_id=thread_id,
            status="queued",
            model=model_name,
            system_message=json.dumps(system_message),
            tools=json.dumps(full_tools) if full_tools else None,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            tool_choice=tool_choice,
            execute_tools_async=execute_tools_async,
            response_format=json.dumps(response_format) if response_format else None
        )

        if await self.should_stop(thread_id):
            yield {"status": "stopped", "message": "Session cancelled"}
            return
        async with self.db.get_async_session() as session:
            session.add(thread_run)
            await session.commit()

        if use_tool_parser:
            hide_tool_msgs = True

        await self.cleanup_incomplete_tool_calls(thread_id)

        # Prepare messages
        messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs)
        prepared_messages = [system_message] + messages

        if additional_instructions:
            additional_instruction_message = {
                "role": "user",
                "content": additional_instructions
            }
            prepared_messages.append(additional_instruction_message)

        try:
            response_stream = make_llm_api_call(
            thread_run.status = "in_progress"
            thread_run.started_at = int(datetime.utcnow().timestamp())
            await self.update_thread_run(thread_run)

            if await self.should_stop(thread_id):
                thread_run.status = "cancelled"
                thread_run.cancelled_at = int(datetime.utcnow().timestamp())
                await self.update_thread_run(thread_run)
                return {"status": "stopped", "message": "Session cancelled"}

            if use_tool_parser:
                hide_tool_msgs = True

            await self.cleanup_incomplete_tool_calls(thread_id)

            # Prepare messages
            messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs)
            prepared_messages = [system_message] + messages

            if additional_system_message:
                additional_instruction_message = {
                    "role": "user",
                    "content": additional_system_message
                }
                prepared_messages.append(additional_instruction_message)

            response = await make_llm_api_call(
                prepared_messages,
                model_name,
                temperature=temperature,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
                stream=stream
                tools=full_tools,
                tool_choice=tool_choice,
                stream=False,
                top_p=top_p,
                response_format=response_format
            )

            async for partial_response in response_stream:
                if stream:
                    yield partial_response
                else:
                    response = partial_response
            usage = response.usage if hasattr(response, 'usage') else None
            usage_dict = self.serialize_usage(usage) if usage else None
            thread_run.usage = usage_dict

            if not stream:
                if tools is None or use_tool_parser:
                    await self.handle_response_without_tools(thread_id, response, use_tool_parser)
                else:
                    await self.handle_response_with_tools(thread_id, response, execute_tools_async)
            # Add the assistant's message to the thread
            assistant_message = {
                "role": "assistant",
                "content": response.choices[0].message['content']
            }
            if 'tool_calls' in response.choices[0].message:
                assistant_message['tool_calls'] = response.choices[0].message['tool_calls']

            await self.add_message(thread_id, assistant_message)

                if await self.should_stop(thread_id):
                    yield {"status": "stopped", "message": "Session cancelled"}
                else:
                    await self.save_thread_run(thread_id)
                    yield response
            if tools is None or use_tool_parser:
                await self.handle_response_without_tools(thread_id, response, use_tool_parser)
            else:
                await self.handle_response_with_tools(thread_id, response, execute_tools_async)

            thread_run.status = "completed"
            thread_run.completed_at = int(datetime.utcnow().timestamp())
            await self.update_thread_run(thread_run)

            return {
                "id": thread_run.id,
                "choices": [self.serialize_choice(choice) for choice in response.choices],
                "usage": usage_dict,
                "model": model_name,
                "object": "chat.completion",
                "created": int(datetime.utcnow().timestamp())
            }

        except Exception as e:
            error_message = f"Error in API call: {str(e)}\n\nFull error: {repr(e)}"
            logging.error(error_message)
            await self.update_thread_run_with_error(thread_id, error_message)
            yield {"status": "error", "message": error_message}
            thread_run.status = "failed"
            thread_run.failed_at = int(datetime.utcnow().timestamp())
            thread_run.last_error = error_message
            await self.update_thread_run(thread_run)
            return {"status": "error", "message": error_message}

    def serialize_usage(self, usage):
        return {
            "completion_tokens": usage.completion_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "total_tokens": usage.total_tokens,
            "completion_tokens_details": self.serialize_completion_tokens_details(usage.completion_tokens_details),
            "prompt_tokens_details": self.serialize_prompt_tokens_details(usage.prompt_tokens_details)
        }

    def serialize_completion_tokens_details(self, details):
        return {
            "audio_tokens": details.audio_tokens,
            "reasoning_tokens": details.reasoning_tokens
        }

    def serialize_prompt_tokens_details(self, details):
        return {
            "audio_tokens": details.audio_tokens,
            "cached_tokens": details.cached_tokens
        }

    def serialize_choice(self, choice):
        return {
            "finish_reason": choice.finish_reason,
            "index": choice.index,
            "message": self.serialize_message(choice.message)
        }

    def serialize_message(self, message):
        return {
            "content": message.content,
            "role": message.role,
            "tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None
        }

    def serialize_tool_call(self, tool_call):
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments
            }
        }

    async def update_thread_run(self, thread_run: ThreadRun):
        async with self.db.get_async_session() as session:
            session.add(thread_run)
            await session.commit()
            await session.refresh(thread_run)

    async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
        response_content = response.choices[0].message['content']

@@ -282,8 +373,8 @@ class ThreadManager:
        if use_tool_parser:
            await self.handle_tool_parser_response(thread_id, response_content)
        else:
            logging.info("Adding assistant message to thread.")
            await self.add_message(thread_id, {"role": "assistant", "content": response_content})
            # The message has already been added in the run_thread method, so we don't need to add it again here
            pass

    async def handle_tool_parser_response(self, thread_id: int, response_content: str):
        tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content)

@@ -325,9 +416,8 @@ class ThreadManager:
        response_message = response.choices[0].message
        tool_calls = response_message.get('tool_calls', [])

        assistant_message = self.create_assistant_message_with_tools(response_message)
        await self.add_message(thread_id, assistant_message)

        # The assistant message has already been added in the run_thread method

        available_functions = self.get_available_functions()

        if await self.should_stop(thread_id):

@@ -348,9 +438,7 @@ class ThreadManager:

        except AttributeError as e:
            logging.error(f"AttributeError: {e}")
            content = response_message.get('content', '')
            if content:
                await self.add_message(thread_id, {"role": "assistant", "content": content})
            # No need to add the message here as it's already been added in the run_thread method

    def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
        message = {

@@ -452,7 +540,7 @@ class ThreadManager:
        result = await session.execute(stmt)
        return result.scalar_one_or_none() is not None

    async def save_thread_run(self, thread_id: int):
    async def save_thread_run(self, thread_id: str):
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:

@@ -461,14 +549,26 @@ class ThreadManager:
            messages = json.loads(thread.messages)
            creation_date = datetime.now().isoformat()

            new_thread_run = ThreadRun(
                thread_id=thread_id,
                messages=json.dumps(messages),
                creation_date=creation_date,
                status='completed'
            )
            session.add(new_thread_run)
            await session.commit()
            # Get the latest ThreadRun for this thread
            stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
            result = await session.execute(stmt)
            latest_thread_run = result.scalar_one_or_none()

            if latest_thread_run:
                # Update the existing ThreadRun
                latest_thread_run.messages = json.dumps(messages)
                latest_thread_run.last_updated_date = creation_date
                await session.commit()
            else:
                # Create a new ThreadRun if none exists
                new_thread_run = ThreadRun(
                    thread_id=thread_id,
                    messages=json.dumps(messages),
                    creation_date=creation_date,
                    status='completed'
                )
                session.add(new_thread_run)
                await session.commit()

    async def get_thread(self, thread_id: int) -> Optional[Thread]:
        async with self.db.get_async_session() as session:

@@ -489,11 +589,92 @@ class ThreadManager:
        result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
        return result.scalars().all()

    async def get_latest_thread_run(self, thread_id: int):
    async def get_latest_thread_run(self, thread_id: str):
        async with self.db.get_async_session() as session:
            stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.run_id.desc()).limit(1)
            stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()
            latest_run = result.scalar_one_or_none()
            if latest_run:
                return {
                    "id": latest_run.id,
                    "status": latest_run.status,
                    "error_message": latest_run.last_error,
                    "created_at": latest_run.created_at,
                    "started_at": latest_run.started_at,
                    "completed_at": latest_run.completed_at,
                    "cancelled_at": latest_run.cancelled_at,
                    "failed_at": latest_run.failed_at,
                    "model": latest_run.model,
                    "usage": latest_run.usage
                }
            return None

    async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id:
                return {
                    "id": run.id,
                    "thread_id": run.thread_id,
                    "status": run.status,
                    "created_at": run.created_at,
                    "started_at": run.started_at,
                    "completed_at": run.completed_at,
                    "cancelled_at": run.cancelled_at,
                    "failed_at": run.failed_at,
                    "model": run.model,
                    "system_message": json.loads(run.system_message) if run.system_message else None,
                    "tools": json.loads(run.tools) if run.tools else None,
                    "usage": run.usage,
                    "temperature": run.temperature,
                    "top_p": run.top_p,
                    "max_tokens": run.max_tokens,
                    "tool_choice": run.tool_choice,
                    "execute_tools_async": run.execute_tools_async,
                    "response_format": json.loads(run.response_format) if run.response_format else None,
                    "last_error": run.last_error
                }
            return None

    async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id and run.status == "in_progress":
                run.status = "cancelled"
                run.cancelled_at = int(datetime.utcnow().timestamp())
                await session.commit()
                return await self.get_run(thread_id, run_id)
            return None

    async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit)
            result = await session.execute(stmt)
            runs = result.scalars().all()
            return [
                {
                    "id": run.id,
                    "thread_id": run.thread_id,
                    "status": run.status,
                    "created_at": run.created_at,
                    "started_at": run.started_at,
                    "completed_at": run.completed_at,
                    "cancelled_at": run.cancelled_at,
                    "failed_at": run.failed_at,
                    "model": run.model,
                    "system_message": json.loads(run.system_message) if run.system_message else None,
                    "tools": json.loads(run.tools) if run.tools else None,
                    "usage": run.usage,
                    "temperature": run.temperature,
                    "top_p": run.top_p,
                    "max_tokens": run.max_tokens,
                    "tool_choice": run.tool_choice,
                    "execute_tools_async": run.execute_tools_async,
                    "response_format": json.loads(run.response_format) if run.response_format else None,
                    "last_error": run.last_error
                }
                for run in runs
            ]

if __name__ == "__main__":
    import asyncio

@@ -519,6 +700,4 @@ if __name__ == "__main__":

    print(f"Response: {response}")

    asyncio.run(main())

    asyncio.run(main())

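Aside (not part of the commit): a sketch driving the reworked run_thread directly. It assumes a configured Database and at least one user message; the model name is illustrative, and on failure the returned dict carries "status"/"message" instead of "id".

# --- sketch: run a thread synchronously and poll its run record ---
import asyncio
from core.db import Database
from core.thread_manager import ThreadManager

async def demo():
    manager = ThreadManager(Database())
    thread_id = await manager.create_thread()
    await manager.add_message(thread_id, {"role": "user", "content": "Hello!"})
    result = await manager.run_thread(
        thread_id=thread_id,
        system_message={"role": "system", "content": "You are a helpful assistant."},
        model_name="gpt-4o",
    )
    # run_thread now returns a chat.completion-style dict rather than a generator
    run = await manager.get_run(thread_id, result["id"])
    print(run["status"], run["usage"])

asyncio.run(demo())
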
@@ -1,5 +1,5 @@
from typing import Dict, Type, Any
from core.tools.tool import Tool
from typing import Dict, Type, Any, Optional
from core.tool import Tool
from core.config import settings
import importlib.util
import os

@@ -39,7 +39,7 @@ class ToolRegistry:

        print(f"Registered tools: {list(self.tools.keys())}") # Debug print

    def get_tool(self, tool_name: str) -> Dict[str, Any]:
    def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
        return self.tools.get(tool_name)

    def get_all_tools(self) -> Dict[str, Dict[str, Any]]:

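Aside (not part of the commit): with the return type now Optional, callers should guard against a missing tool before indexing into the entry, matching how run_thread filters tool names above. The tool name below is hypothetical.

# --- sketch: handling the Optional return of get_tool ---
from core.tool_registry import ToolRegistry

registry = ToolRegistry()
entry = registry.get_tool("web_search")  # hypothetical tool name
if entry is None:
    print("Tool not registered; skipping")
else:
    # The registry stores each tool alongside its schema dict
    print(entry["schema"])
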
@@ -1,97 +0,0 @@
import streamlit as st
import requests
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE

def display_agent_management():
    st.header("Agent Management")

    col1, col2 = st.columns([1, 1])

    with col1:
        display_create_agent_form()

    with col2:
        display_existing_agents()

def display_create_agent_form():
    st.subheader("Create New Agent")
    with st.form("create_agent_form"):
        new_agent_name = st.text_input("Agent Name")
        new_agent_model = st.selectbox("Model", AI_MODELS)
        new_agent_system_prompt = st.text_area("System Prompt", value=STANDARD_SYSTEM_MESSAGE)
        new_agent_temperature = st.slider("Temperature", 0.0, 1.0, 0.5)
        tool_options = list(st.session_state.tools.keys())
        new_agent_tools = st.multiselect("Tools", tool_options)

        submitted = st.form_submit_button("Create Agent")
        if submitted:
            create_agent(new_agent_name, new_agent_model, new_agent_system_prompt, new_agent_temperature, new_agent_tools)

def create_agent(name, model, system_prompt, temperature, tools):
    response = requests.post(f"{API_BASE_URL}/agents/", json={
        "name": name,
        "model": model,
        "system_prompt": system_prompt,
        "temperature": temperature,
        "selected_tools": tools
    })
    if response.status_code == 200:
        st.success("Agent created successfully!")
        st.session_state.fetch_agents()
    else:
        st.error("Failed to create agent.")

def display_existing_agents():
    st.subheader("Existing Agents")
    for agent in st.session_state.agents:
        with st.expander(f"Agent: {agent['name']}"):
            display_agent_details(agent)

def display_agent_details(agent):
    st.write(f"Model: {agent['model']}")
    st.write(f"Temperature: {agent['temperature']}")
    st.write(f"Tools: {', '.join(agent['selected_tools'])}")

    if st.button(f"Edit Agent {agent['id']}"):
        st.session_state.editing_agent = agent['id']

    if st.button(f"Delete Agent {agent['id']}"):
        delete_agent(agent['id'])

    if st.session_state.get('editing_agent') == agent['id']:
        edit_agent_form(agent)

def edit_agent_form(agent):
    with st.form(f"edit_agent_form_{agent['id']}"):
        updated_name = st.text_input("Agent Name", value=agent['name'])
        updated_model = st.selectbox("Model", AI_MODELS, index=AI_MODELS.index(agent['model']))
        updated_system_prompt = st.text_area("System Prompt", value=agent['system_prompt'])
        updated_temperature = st.slider("Temperature", 0.0, 1.0, value=agent['temperature'])
        tool_options = list(st.session_state.tools.keys())
        updated_tools = st.multiselect("Tools", options=tool_options, default=agent['selected_tools'])

        if st.form_submit_button("Update Agent"):
            update_agent(agent['id'], updated_name, updated_model, updated_system_prompt, updated_temperature, updated_tools)
            st.session_state.editing_agent = None

def update_agent(agent_id, name, model, system_prompt, temperature, tools):
    response = requests.put(f"{API_BASE_URL}/agents/{agent_id}", json={
        "name": name,
        "model": model,
        "system_prompt": system_prompt,
        "temperature": temperature,
        "selected_tools": tools
    })
    if response.status_code == 200:
        st.success("Agent updated successfully!")
        st.session_state.fetch_agents()
    else:
        st.error("Failed to update agent.")

def delete_agent(agent_id):
    response = requests.delete(f"{API_BASE_URL}/agents/{agent_id}")
    if response.status_code == 200:
        st.success("Agent deleted successfully!")
        st.session_state.fetch_agents()
    else:
        st.error("Failed to delete agent.")

@@ -1,34 +1,45 @@
import streamlit as st
from core.ui.thread_management import display_thread_management
from core.ui.message_display import display_messages
from core.ui.thread_runner import display_thread_runner
from core.ui.agent_management import display_agent_management
from core.ui.message_display import display_messages_and_runner
from core.ui.thread_runner import fetch_thread_runs, display_runs
from core.ui.tool_display import display_tools
from core.ui.utils import initialize_session_state, fetch_data
from core.ui.utils import initialize_session_state, fetch_data, API_BASE_URL

def main():
    initialize_session_state()
    fetch_data()

    st.set_page_config(page_title="AI Assistant Management System", layout="wide")

    st.sidebar.title("Navigation")
    mode = st.sidebar.radio("Select Mode", ["Agent Management", "Thread Management", "Tools"])
    mode = st.sidebar.radio("Select Mode", ["Thread Management", "Tools"])

    st.title("AI Assistant Management System")

    if mode == "Agent Management":
        display_agent_management()
    elif mode == "Tools":
    if mode == "Tools":
        display_tools()
    else:  # Thread Management
        display_thread_management_content()

def display_thread_management_content():
    st.header("Thread Management")
    display_thread_management()
    col1, col2 = st.columns([1, 3])

    with col1:
        display_thread_management()
        if st.session_state.selected_thread:
            display_thread_runner(st.session_state.selected_thread)

    with col2:
        if st.session_state.selected_thread:
            display_messages_and_runner(st.session_state.selected_thread)

    if st.session_state.selected_thread:
        display_messages(st.session_state.selected_thread)
        display_thread_runner(st.session_state.selected_thread)
def display_thread_runner(thread_id):
    st.subheader("Thread Runs")

    limit = st.number_input("Number of runs to retrieve", min_value=1, max_value=100, value=20)
    if st.button("Fetch Runs"):
        runs = fetch_thread_runs(thread_id, limit)
        display_runs(runs)

if __name__ == "__main__":
    main()

@@ -1,16 +1,21 @@
import streamlit as st
import requests
import json
from core.ui.utils import API_BASE_URL
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
from core.ui.thread_runner import prepare_run_thread_data, run_thread, display_response_content

def display_messages(thread_id):
    st.subheader(f"🧵 Thread ID: {thread_id}")
    # st.write("### 📝 Messages")
def display_messages_and_runner(thread_id):
    st.subheader(f"Messages for Thread: {thread_id}")

    messages = fetch_messages(thread_id)
    display_message_json(messages)
    display_message_list(messages)
    display_add_message_form(thread_id)

    col1, col2 = st.columns(2)

    with col1:
        display_add_message_form(thread_id)

    with col2:
        display_run_thread_form(thread_id)

def fetch_messages(thread_id):
    messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/")

@@ -20,28 +25,37 @@ def fetch_messages(thread_id):
        st.error("Failed to fetch messages.")
        return []

def display_message_json(messages):
    if st.button("Show/Hide JSON"):
        st.session_state.show_json = not st.session_state.get('show_json', False)

    if st.session_state.get('show_json', False):
        json_str = json.dumps(messages, indent=2)
        st.code(json_str, language="json")

def display_message_list(messages):
    for msg in messages:
        with st.chat_message(msg['role']):
            st.write(msg['content'])

def display_add_message_form(thread_id):
    st.write("### ➕ Add a New Message")
    st.write("### Add a New Message")
    with st.form(key="add_message_form"):
        role = st.selectbox("🔹 Role", ["user", "assistant"], key="add_role")
        content = st.text_area("📝 Content", key="add_content")
        submitted = st.form_submit_button("➕ Add Message")
        role = st.selectbox("Role", ["user", "assistant"], key="add_role")
        content = st.text_area("Content", key="add_content")
        submitted = st.form_submit_button("Add Message")
        if submitted:
            add_message(thread_id, role, content)

def display_run_thread_form(thread_id):
    st.write("### Run Thread")
    with st.form(key="run_thread_form"):
        model_name = st.selectbox("Model", AI_MODELS, key="model_name")
        temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
        max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
        system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
        additional_system_message = st.text_area("Additional System Message", key="additional_system_message", height=100)

        tool_options = list(st.session_state.tools.keys())
        selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")

        submitted = st.form_submit_button("Run Thread")
        if submitted:
            run_thread_data = prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools)
            run_thread(thread_id, run_thread_data)

def add_message(thread_id, role, content):
    message_data = {"role": role, "content": content}
    add_msg_response = requests.post(

@@ -49,6 +63,7 @@ def add_message(thread_id, role, content):
        json=message_data
    )
    if add_msg_response.status_code == 200:
        st.success("Message added successfully.")
        st.rerun()
    else:
        st.error("Failed to add message.")

@@ -1,22 +1,22 @@
import streamlit as st
import requests
from core.ui.utils import API_BASE_URL
from datetime import datetime

def display_thread_management():
    col1, col2 = st.columns([1, 2])
    st.subheader("Thread Management")

    with col1:
        if st.button("➕ Create New Thread"):
            create_new_thread()
    if st.button("➕ Create New Thread", key="create_thread_button"):
        create_new_thread()

    with col2:
        display_thread_selector()
    display_thread_selector()

def create_new_thread():
    response = requests.post(f"{API_BASE_URL}/threads/")
    if response.status_code == 200:
        thread_id = response.json()['thread_id']
        st.session_state.selected_thread = thread_id
        st.success(f"New thread created with ID: {thread_id}")
        st.rerun()
    else:
        st.error("Failed to create a new thread.")

@@ -25,17 +25,36 @@ def display_thread_selector():
    threads_response = requests.get(f"{API_BASE_URL}/threads/")
    if threads_response.status_code == 200:
        threads = threads_response.json()
        thread_options = [str(thread['thread_id']) for thread in threads]

        # Sort threads by creation date if available, otherwise by thread_id
        def sort_key(thread):
            if 'creation_date' in thread:
                try:
                    return datetime.strptime(thread['creation_date'], "%Y-%m-%d %H:%M:%S")
                except ValueError:
                    st.warning(f"Invalid date format for thread {thread['thread_id']}")
            return thread['thread_id']

        sorted_threads = sorted(threads, key=sort_key, reverse=True)

        thread_options = []
        for thread in sorted_threads:
            if 'creation_date' in thread:
                thread_options.append(f"{thread['thread_id']} - Created: {thread['creation_date']}")
            else:
                thread_options.append(thread['thread_id'])

        if st.session_state.selected_thread is None and threads:
            st.session_state.selected_thread = str(threads[0]['thread_id'])
        if st.session_state.selected_thread is None and sorted_threads:
            st.session_state.selected_thread = sorted_threads[0]['thread_id']

        selected_thread = st.selectbox(
            "🔍 Select Thread",
            thread_options,
            key="thread_select",
            index=thread_options.index(str(st.session_state.selected_thread)) if st.session_state.selected_thread else 0
            index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0)
        )

        if selected_thread:
            st.session_state.selected_thread = int(selected_thread)
            st.session_state.selected_thread = selected_thread.split(' - ')[0] if ' - ' in selected_thread else selected_thread
    else:
        st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}")

@ -1,66 +1,16 @@
|
|||
import streamlit as st
|
||||
import requests
|
||||
import json
|
||||
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
|
||||
from core.ui.utils import API_BASE_URL
|
||||
from datetime import datetime
|
||||
|
||||
def display_thread_runner(thread_id):
|
||||
st.write("## ⚙️ Run Thread")
|
||||
|
||||
manual_config = display_manual_setup_tab()
|
||||
|
||||
# Common settings
|
||||
additional_instructions = st.text_area("Additional Instructions", key="additional_instructions", height=100)
|
||||
stream = st.checkbox("📡 Stream Responses", key="stream_responses")
|
||||
|
||||
# Prepare the run thread data
|
||||
run_thread_data = prepare_run_thread_data(manual_config, additional_instructions, stream)
|
||||
|
||||
# Display the preview of the request payload in an expander
|
||||
with st.expander("📤 Preview Request Payload", expanded=False):
|
||||
st.json(run_thread_data)
|
||||
|
||||
# Center the run button and make it more prominent
|
||||
col1, col2, col3 = st.columns([1, 2, 1])
|
||||
with col2:
|
||||
if st.button("▶️ Run Thread", key="run_thread_button", use_container_width=True):
|
||||
run_thread(thread_id, run_thread_data)
|
||||
|
||||
display_thread_run_status(thread_id)
|
||||
|
||||
def display_manual_setup_tab():
|
||||
model_name = st.selectbox("Model", AI_MODELS, key="model_name")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
|
||||
with col2:
|
||||
max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
|
||||
|
||||
system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
|
||||
|
||||
tool_options = list(st.session_state.tools.keys())
|
||||
selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
|
||||
|
||||
manual_config = {
|
||||
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools):
|
||||
return {
|
||||
"system_message": {"role": "system", "content": system_message},
|
||||
"model_name": model_name,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"system_message": system_message,
|
||||
"selected_tools": selected_tools
|
||||
}
|
||||
|
||||
return manual_config
|
||||
|
||||
def prepare_run_thread_data(manual_config, additional_instructions, stream):
|
||||
tools = [st.session_state.tools[tool]['schema'] for tool in manual_config['selected_tools'] if tool in st.session_state.tools]
|
||||
return {
|
||||
"system_message": {"role": "system", "content": manual_config['system_message']},
|
||||
"model_name": manual_config['model_name'],
|
||||
"temperature": manual_config['temperature'],
|
||||
"max_tokens": manual_config['max_tokens'],
|
||||
"tools": tools,
|
||||
"additional_instructions": additional_instructions,
|
||||
"stream": stream
|
||||
"tools": selected_tools,
|
||||
"additional_system_message": additional_system_message
|
||||
}
|
||||
|
||||
def run_thread(thread_id, run_thread_data):
|
||||
|
@ -73,49 +23,77 @@ def run_thread(thread_id, run_thread_data):
|
|||
        response_data = run_thread_response.json()
        st.success("Thread run completed successfully!")

        # Display the return payload in an expander
        with st.expander("📥 Return Payload", expanded=False):
            st.json(response_data)
        if 'id' in response_data:
            st.session_state.latest_run_id = response_data['id']

        # Display the actual response content
        st.write("### 📬 Response Content")
        display_response_content(response_data)

        st.rerun()
    else:
        st.error("Failed to run thread.")
        with st.expander("❌ Error Response", expanded=True):
            st.json(run_thread_response.json())

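The hunk above begins mid-function, after the request has already been sent. For orientation only, a minimal sketch of the elided opening; the POST path mirrors the status and runs URLs used below, but it is an assumption, not taken from this commit:

# Hypothetical opening of run_thread; the endpoint path is assumed.
run_thread_response = requests.post(
    f"{API_BASE_URL}/threads/{thread_id}/run/",
    json=run_thread_data,
)
if run_thread_response.status_code == 200:
    ...  # the hunk shown above continues from here
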
def display_response_content(response_data):
    if isinstance(response_data, dict) and 'response' in response_data:
        for item in response_data['response']:
            if isinstance(item, dict):
                if 'content' in item:
                    st.markdown(item['content'])
                elif 'tool_calls' in item:
                    st.write("**Tool Calls:**")
                    for tool_call in item['tool_calls']:
                        st.write(f"- Function: `{tool_call['function']['name']}`")
                        st.code(tool_call['function']['arguments'], language="json")
            elif isinstance(item, str):
                st.markdown(item)
    else:
        st.json(response_data)

def display_thread_run_status(thread_id):
    status_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/run/status/")
    if status_response.status_code == 200:
        status_data = status_response.json()
        st.write("### ⚙️ Thread Run Status")
        status = status_data.get('status')
        if status == 'completed':
            st.success(f"**Status:** {status}")
        elif status == 'error':
            st.error(f"**Status:** {status}")
            with st.expander("Error Details", expanded=True):
                st.code(status_data.get('error_message'), language="")
        else:
            st.info(f"**Status:** {status}")
    else:
        st.warning("Could not retrieve thread run status.")

def fetch_thread_runs(thread_id, limit):
    response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs?limit={limit}")
    if response.status_code == 200:
        return response.json()
    else:
        st.error("Failed to retrieve runs.")
        return []

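A caller can poll this status endpoint until the run reaches a terminal state. A minimal sketch; the one-second interval and the set of terminal statuses are assumptions, only 'completed' and 'error' appear in this diff:

import time

def wait_for_run(thread_id, poll_interval=1.0):
    # Poll the status endpoint until the run completes or errors (assumed terminal states).
    while True:
        resp = requests.get(f"{API_BASE_URL}/threads/{thread_id}/run/status/")
        resp.raise_for_status()
        status = resp.json().get('status')
        if status in ('completed', 'error'):
            return status
        time.sleep(poll_interval)
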
def format_timestamp(timestamp):
    if timestamp:
        return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
    return 'N/A'

def display_runs(runs):
    for run in runs:
        with st.expander(f"Run {run['id']} - Status: {run['status']}", expanded=False):
            col1, col2 = st.columns(2)
            with col1:
                st.write(f"**Created At:** {format_timestamp(run['created_at'])}")
                st.write(f"**Started At:** {format_timestamp(run['started_at'])}")
                st.write(f"**Completed At:** {format_timestamp(run['completed_at'])}")
                st.write(f"**Cancelled At:** {format_timestamp(run['cancelled_at'])}")
                st.write(f"**Failed At:** {format_timestamp(run['failed_at'])}")
            with col2:
                st.write(f"**Model:** {run['model']}")
                st.write(f"**Temperature:** {run['temperature']}")
                st.write(f"**Top P:** {run['top_p']}")
                st.write(f"**Max Tokens:** {run['max_tokens']}")
                st.write(f"**Tool Choice:** {run['tool_choice']}")
                st.write(f"**Execute Tools Async:** {run['execute_tools_async']}")

            st.write("**System Message:**")
            st.json(run['system_message'])

            if run['tools']:
                st.write("**Tools:**")
                st.json(run['tools'])

            if run['usage']:
                st.write("**Usage:**")
                st.json(run['usage'])

            if run['response_format']:
                st.write("**Response Format:**")
                st.json(run['response_format'])

            if run['last_error']:
                st.error("**Last Error:**")
                st.code(run['last_error'])

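Taken together, the two helpers pair naturally at a call site such as the following (the limit of 10 is illustrative):

# Fetch the most recent runs for a thread and render them in the UI.
runs = fetch_thread_runs(thread_id, limit=10)
display_runs(runs)
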
@ -4,13 +4,6 @@ from core.constants import AI_MODELS, STANDARD_SYSTEM_MESSAGE

API_BASE_URL = "http://localhost:8000"

def fetch_agents():
    response = requests.get(f"{API_BASE_URL}/agents/")
    if response.status_code == 200:
        st.session_state.agents = response.json()
    else:
        st.error("Failed to fetch agents.")

def fetch_tools():
    response = requests.get(f"{API_BASE_URL}/tools/")
    if response.status_code == 200:

@ -21,15 +14,10 @@ def fetch_tools():

def initialize_session_state():
    if 'selected_thread' not in st.session_state:
        st.session_state.selected_thread = None
    if 'agents' not in st.session_state:
        st.session_state.agents = []
    if 'tools' not in st.session_state:
        # Tools are indexed by name elsewhere (st.session_state.tools[tool]['schema']),
        # so this must be a dict, not a list.
        st.session_state.tools = {}
    if 'fetch_agents' not in st.session_state:
        st.session_state.fetch_agents = fetch_agents
    if 'fetch_tools' not in st.session_state:
        st.session_state.fetch_tools = fetch_tools

def fetch_data():
    fetch_agents()
    fetch_tools()

@ -417,6 +417,26 @@ files = [
    {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]

[[package]]
name = "fastapi"
version = "0.115.0"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.8"
files = [
    {file = "fastapi-0.115.0-py3-none-any.whl", hash = "sha256:17ea427674467486e997206a5ab25760f6b09e069f099b96f5b55a32fb6f1631"},
    {file = "fastapi-0.115.0.tar.gz", hash = "sha256:f93b4ca3529a8ebc6fc3fcf710e5efa8de3df9b41570958abf1d97d843138004"},
]

[package.dependencies]
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
starlette = ">=0.37.2,<0.39.0"
typing-extensions = ">=4.8.0"

[package.extras]
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"]

[[package]]
name = "filelock"
version = "3.16.1"

@ -2284,6 +2304,53 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"]

[[package]]
name = "sse-starlette"
version = "2.1.3"
description = "SSE plugin for Starlette"
optional = false
python-versions = ">=3.8"
files = [
    {file = "sse_starlette-2.1.3-py3-none-any.whl", hash = "sha256:8ec846438b4665b9e8c560fcdea6bc8081a3abf7942faa95e5a744999d219772"},
    {file = "sse_starlette-2.1.3.tar.gz", hash = "sha256:9cd27eb35319e1414e3d2558ee7414487f9529ce3b3cf9b21434fd110e017169"},
]

[package.dependencies]
anyio = "*"
starlette = "*"
uvicorn = "*"

[package.extras]
examples = ["fastapi"]

[[package]]
name = "sseclient-py"
version = "1.7.2"
description = "SSE client for Python"
optional = false
python-versions = "*"
files = [
    {file = "sseclient-py-1.7.2.tar.gz", hash = "sha256:ba3197d314766eccb72a1dda80b5fa14a0fbba07d796a287654c07edde88fe0f"},
    {file = "sseclient_py-1.7.2-py2.py3-none-any.whl", hash = "sha256:a758653b13b78df42cdb696740635a26cb72ad433b75efb68dbbb163d099b6a9"},
]

[[package]]
name = "starlette"
version = "0.38.6"
description = "The little ASGI library that shines."
optional = false
python-versions = ">=3.8"
files = [
    {file = "starlette-0.38.6-py3-none-any.whl", hash = "sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05"},
    {file = "starlette-0.38.6.tar.gz", hash = "sha256:863a1588f5574e70a821dadefb41e4881ea451a47a3cd1b4df359d4ffefe5ead"},
]

[package.dependencies]
anyio = ">=3.4.0,<5"

[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]

[[package]]
name = "streamlit"
version = "1.39.0"

@ -2615,6 +2682,24 @@ h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]

[[package]]
name = "uvicorn"
version = "0.31.0"
description = "The lightning-fast ASGI server."
optional = false
python-versions = ">=3.8"
files = [
    {file = "uvicorn-0.31.0-py3-none-any.whl", hash = "sha256:cac7be4dd4d891c363cd942160a7b02e69150dcbc7a36be04d5f4af4b17c8ced"},
    {file = "uvicorn-0.31.0.tar.gz", hash = "sha256:13bc21373d103859f68fe739608e2eb054a816dea79189bc3ca08ea89a275906"},
]

[package.dependencies]
click = ">=7.0"
h11 = ">=0.8"

[package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]

[[package]]
name = "watchdog"
version = "5.0.3"

@ -2784,4 +2869,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "11cd10192810fd94515b1f6ad5d8ce27b5d665d1d4e65fff9b05b4cd80c9d377"

@ -21,6 +21,9 @@ litellm = "^1.44.4"
pytest = "^8.3.2"
pytest-asyncio = "^0.24.0"
agentops = "^0.3.10"
sseclient-py = "1.7.2"
fastapi = "^0.115.0"
sse-starlette = "^2.1.3"

[tool.poetry.group.dev.dependencies]

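The new sseclient-py, fastapi, and sse-starlette pins point at server-sent-event streaming between the API and this client. A minimal sketch of how the pair fits together; the route path and event payloads are illustrative, not taken from this commit:

# Server side: stream tokens from a FastAPI route via sse-starlette.
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

@app.get("/threads/{thread_id}/run/stream")  # hypothetical route
async def stream_run(thread_id: int):
    async def event_generator():
        for chunk in ("Hello", ", ", "world"):  # stand-in for model output
            yield {"event": "token", "data": chunk}
    return EventSourceResponse(event_generator())

# Client side: iterate over the event stream with sseclient-py.
import requests
import sseclient

resp = requests.get("http://localhost:8000/threads/1/run/stream", stream=True)
for event in sseclient.SSEClient(resp).events():
    print(event.data)
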
@ -1,7 +1,7 @@
import os
import asyncio
from typing import Dict, Any
from core.tool import Tool, ToolResult  # moved from core.tools.tool
from core.config import settings

class FilesTool(Tool):

@ -1,5 +1,5 @@
from typing import Dict, Any
from core.tool import Tool, ToolResult  # moved from core.tools.tool

class ExampleTool(Tool):
    description = "An example tool for demonstration purposes."

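Both tools now import the base class from core.tool instead of core.tools.tool. For orientation, a minimal subclass under the new layout might look like the sketch below; the execute signature and the ToolResult fields are assumptions, since the base class itself is not shown in this diff:

from typing import Dict, Any
from core.tool import Tool, ToolResult

class EchoTool(Tool):
    description = "Echoes its input back (hypothetical example)."

    # Method name and ToolResult fields are assumed for illustration.
    async def execute(self, params: Dict[str, Any]) -> ToolResult:
        return ToolResult(success=True, output=params.get("text", ""))
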
@ -0,0 +1 @@
Hello there! This is a sample file created just for you.

@ -0,0 +1 @@
This file contains random information about this assistant. I'm here to help you with a wide range of tasks, from answering questions to managing files.

@ -1 +0,0 @@
This is some random content for the file.

@ -1 +0,0 @@
random contents

@ -0,0 +1,4 @@
Here are some random tips:
1. Stay curious and keep learning.
2. Organize your workspace for better productivity.
3. Take breaks to maintain focus and energy.