This commit is contained in:
marko-kraemer 2024-10-07 21:13:11 +02:00
parent a933193851
commit 6c903fa761
24 changed files with 702 additions and 559 deletions

View File

@ -1,58 +0,0 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from core.db import Agent
from datetime import datetime
from typing import List, Optional
import json
class AgentManager:
def __init__(self, db):
self.db = db
async def create_agent(self, model: str, name: str, system_prompt: str, selected_tools: List[str], temperature: float = 0.5) -> int:
async with self.db.get_async_session() as session:
new_agent = Agent(
model=model,
name=name,
system_prompt=system_prompt,
selected_tools=selected_tools, # Store as a list directly
temperature=temperature,
created_at=datetime.now().isoformat()
)
session.add(new_agent)
await session.commit()
await session.refresh(new_agent)
return new_agent.id
async def get_agent(self, agent_id: int) -> Optional[Agent]:
async with self.db.get_async_session() as session:
result = await session.execute(select(Agent).filter(Agent.id == agent_id))
agent = result.scalar_one_or_none()
return agent
async def update_agent(self, agent_id: int, **kwargs) -> bool:
async with self.db.get_async_session() as session:
result = await session.execute(select(Agent).filter(Agent.id == agent_id))
agent = result.scalar_one_or_none()
if agent:
for key, value in kwargs.items():
setattr(agent, key, value)
await session.commit()
return True
return False
async def delete_agent(self, agent_id: int) -> bool:
async with self.db.get_async_session() as session:
result = await session.execute(select(Agent).filter(Agent.id == agent_id))
agent = result.scalar_one_or_none()
if agent:
await session.delete(agent)
await session.commit()
return True
return False
async def list_agents(self) -> List[Agent]:
async with self.db.get_async_session() as session:
result = await session.execute(select(Agent))
agents = result.scalars().all()
return agents

View File

@ -1,109 +1,107 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Query, Path
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
import asyncio
from fastapi.responses import StreamingResponse
from core.db import Database
from core.thread_manager import ThreadManager
from core.agent_manager import AgentManager
from core.tools.tool_registry import ToolRegistry
from core.tool_registry import ToolRegistry
from core.config import Settings
app = FastAPI()
app = FastAPI(
title="Thread Manager API",
description="API for managing and running threads with LLM integration",
version="1.0.0",
)
db = Database()
manager = ThreadManager(db)
tool_registry = ToolRegistry() # Initialize here
agent_manager = AgentManager(db)
tool_registry = ToolRegistry()
# Pydantic models for request and response bodies
class Message(BaseModel):
role: str
content: str
role: str = Field(..., description="The role of the message sender (e.g., 'user', 'assistant')")
content: str = Field(..., description="The content of the message")
class RunThreadRequest(BaseModel):
system_message: Dict[str, Any]
model_name: str
temperature: float = 0.5
max_tokens: Optional[int] = 500
tools: Optional[List[str]] = None
tool_choice: str = "required"
additional_instructions: Optional[str] = None
stream: bool = False
system_message: Dict[str, Any] = Field(..., description="The system message to be used for the thread run")
model_name: str = Field(..., description="The name of the LLM model to be used")
temperature: float = Field(0.5, description="The sampling temperature for the LLM")
max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate")
tools: Optional[List[str]] = Field(None, description="The list of tools to be used in the thread run")
tool_choice: str = Field("auto", description="Controls which tool is called by the model")
additional_system_message: Optional[str] = Field(None, description="Additional system message to be appended")
stream: bool = Field(False, description="Whether to stream the response")
top_p: Optional[float] = Field(None, description="The nucleus sampling value")
response_format: Optional[Dict[str, Any]] = Field(None, description="Specifies the format that the model must output")
class Agent(BaseModel):
name: str
model: str
system_prompt: str
selected_tools: List[str]
temperature: float = 0.5
@app.post("/threads/")
@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread")
async def create_thread():
"""
Create a new thread and return its ID.
"""
thread_id = await manager.create_thread()
return {"thread_id": thread_id}
@app.get("/threads/")
@app.get("/threads/", response_model=List[Dict[str, str]], summary="Get all threads")
async def get_threads():
"""
Retrieve a list of all thread IDs.
"""
threads = await manager.get_threads()
return [{"thread_id": thread.thread_id} for thread in threads]
@app.post("/threads/{thread_id}/messages/")
async def add_message(thread_id: int, message: Message):
@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread")
async def add_message(thread_id: str, message: Message):
"""
Add a new message to the specified thread.
"""
await manager.add_message(thread_id, message.dict())
return {"status": "success"}
@app.get("/threads/{thread_id}/messages/")
async def list_messages(thread_id: int):
@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread")
async def list_messages(thread_id: str):
"""
Retrieve all messages from the specified thread.
"""
messages = await manager.list_messages(thread_id)
return messages
@app.post("/threads/{thread_id}/run/")
async def run_thread(thread_id: int, request: Dict[str, Any]):
if 'agent_id' in request:
# Agent-based run
response_gen = manager.run_thread(
thread_id=thread_id,
agent_id=request['agent_id'],
additional_instructions=request.get('additional_instructions'),
stream=request.get('stream', False)
)
@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread")
async def run_thread(thread_id: str, request: RunThreadRequest):
"""
Run the specified thread with the given parameters.
"""
response = await manager.run_thread(
thread_id=thread_id,
system_message=request.system_message,
model_name=request.model_name,
temperature=request.temperature,
max_tokens=request.max_tokens,
tools=request.tools,
additional_system_message=request.additional_system_message,
top_p=request.top_p,
tool_choice=request.tool_choice,
response_format=request.response_format)
return {"status": "success", "response": response}
@app.get("/threads/{thread_id}/run/status/", response_model=Dict[str, Any], summary="Get thread run status")
async def get_thread_run_status(thread_id: str):
"""
Retrieve the status of the latest run for the specified thread.
"""
latest_thread_run = await manager.get_latest_thread_run(thread_id)
if latest_thread_run:
return latest_thread_run
else:
# Manual configuration run
response_gen = manager.run_thread(
thread_id=thread_id,
system_message=request['system_message'],
model_name=request['model_name'],
temperature=request['temperature'],
max_tokens=request['max_tokens'],
tools=request['tools'],
additional_instructions=request.get('additional_instructions'),
stream=request.get('stream', False)
)
return {"status": "No runs found for this thread."}
if request.get('stream', False):
raise HTTPException(status_code=501, detail="Streaming is not supported via this endpoint.")
else:
response = []
async for chunk in response_gen:
response.append(chunk)
return {"response": response}
@app.get("/threads/{thread_id}/run/status/")
async def get_thread_run_status(thread_id: int):
try:
latest_thread_run = await manager.get_latest_thread_run(thread_id)
if latest_thread_run:
return {
"status": latest_thread_run.status,
"error_message": latest_thread_run.error_message
}
else:
return {"status": "No runs found for this thread."}
except AttributeError:
return {"status": "Error", "message": "Unable to retrieve thread run status."}
@app.get("/tools/")
@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools")
async def get_tools():
"""
Retrieve a list of all available tools and their schemas.
"""
tools = tool_registry.get_all_tools()
if not tools:
print("No tools found in the registry") # Debug print
@ -116,52 +114,45 @@ async def get_tools():
for name, tool_info in tools.items()
}
@app.post("/agents/")
async def create_agent(agent: Agent):
agent_id = await agent_manager.create_agent(**agent.dict())
return {"agent_id": agent_id}
@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run")
async def get_run(
thread_id: str = Path(..., description="The ID of the thread that was run"),
run_id: str = Path(..., description="The ID of the run to retrieve")
):
"""
Retrieve the run object matching the specified ID.
"""
run = await manager.get_run(thread_id, run_id)
if run is None:
raise HTTPException(status_code=404, detail="Run not found")
return run
@app.get("/agents/")
async def list_agents():
agents = await agent_manager.list_agents()
return [
{
"id": agent.id,
"name": agent.name,
"model": agent.model,
"system_prompt": agent.system_prompt,
"selected_tools": agent.selected_tools,
"temperature": agent.temperature,
"created_at": agent.created_at
}
for agent in agents
]
@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run")
async def cancel_run(
thread_id: str = Path(..., description="The ID of the thread to which this run belongs"),
run_id: str = Path(..., description="The ID of the run to cancel")
):
"""
Cancels a run that is in_progress.
"""
run = await manager.cancel_run(thread_id, run_id)
if run is None:
raise HTTPException(status_code=404, detail="Run not found")
return run
@app.get("/agents/{agent_id}")
async def get_agent(agent_id: int):
agent = await agent_manager.get_agent(agent_id)
if agent:
return {
"id": agent.id,
"name": agent.name,
"model": agent.model,
"system_prompt": agent.system_prompt,
"selected_tools": agent.selected_tools,
"temperature": agent.temperature,
"created_at": agent.created_at
}
raise HTTPException(status_code=404, detail="Agent not found")
@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs")
async def list_runs(
thread_id: str = Path(..., description="The ID of the thread the runs belong to"),
limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned")
):
"""
Returns a list of runs belonging to a thread.
"""
runs = await manager.list_runs(thread_id, limit)
return runs
@app.put("/agents/{agent_id}")
async def update_agent(agent_id: int, agent: Agent):
success = await agent_manager.update_agent(agent_id, **agent.dict())
if success:
return {"status": "success"}
raise HTTPException(status_code=404, detail="Agent not found")
# Add more endpoints as needed for production use
@app.delete("/agents/{agent_id}")
async def delete_agent(agent_id: int):
success = await agent_manager.delete_agent(agent_id)
if success:
return {"status": "success"}
raise HTTPException(status_code=404, detail="Agent not found")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -1,17 +1,19 @@
from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON
from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON, Boolean
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from core.config import settings # Changed from Settings to settings
import os
from contextlib import asynccontextmanager
import uuid
from datetime import datetime
Base = declarative_base()
class Thread(Base):
__tablename__ = 'threads'
thread_id = Column(Integer, primary_key=True)
thread_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
messages = Column(Text)
creation_date = Column(String)
last_updated_date = Column(String)
@ -22,25 +24,31 @@ class Thread(Base):
class ThreadRun(Base):
__tablename__ = 'thread_runs'
run_id = Column(Integer, primary_key=True)
thread_id = Column(Integer, ForeignKey('threads.thread_id'))
messages = Column(Text)
creation_date = Column(String)
status = Column(String)
error_message = Column(Text, nullable=True)
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
thread_id = Column(String(36), ForeignKey('threads.thread_id'))
created_at = Column(Integer)
status = Column(String)
last_error = Column(Text, nullable=True)
started_at = Column(Integer, nullable=True)
cancelled_at = Column(Integer, nullable=True)
failed_at = Column(Integer, nullable=True)
completed_at = Column(Integer, nullable=True)
model = Column(String)
system_message = Column(Text)
tools = Column(JSON, nullable=True)
usage = Column(JSON, nullable=True)
temperature = Column(Float, nullable=True)
top_p = Column(Float, nullable=True)
max_tokens = Column(Integer, nullable=True)
tool_choice = Column(String, nullable=True)
execute_tools_async = Column(Boolean)
response_format = Column(JSON, nullable=True)
thread = relationship("Thread", back_populates="thread_runs")
class Agent(Base):
__tablename__ = 'agents'
id = Column(Integer, primary_key=True)
model = Column(String, nullable=False)
name = Column(String, nullable=False)
system_prompt = Column(Text, nullable=False)
selected_tools = Column(JSON) # Changed from ARRAY to JSON
temperature = Column(Float, default=0.5)
created_at = Column(String, nullable=False)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.created_at = int(datetime.utcnow().timestamp())
# class MemoryModule(Base):
# __tablename__ = 'memory_modules'

View File

@ -1,4 +1,4 @@
from typing import Union, AsyncGenerator
from typing import Union, Dict, Any
import litellm
import os
import json
@ -26,24 +26,13 @@ os.environ['GROQ_API_KEY'] = GROQ_API_KEY
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False) -> AsyncGenerator[Union[dict, str], None]:
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]:
litellm.set_verbose = True
async def attempt_api_call(api_call_func, max_attempts=3):
for attempt in range(max_attempts):
try:
response = await api_call_func()
if stream:
async for chunk in response:
yield chunk
else:
response_content = response.choices[0].message['content'] if json_mode else response
if json_mode:
if not json.loads(response_content):
logger.info(f"Invalid JSON received, retrying attempt {attempt + 1}")
continue
yield response
return
return await api_call_func()
except litellm.exceptions.RateLimitError as e:
logger.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...")
await asyncio.sleep(30)
@ -60,7 +49,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
"model": model_name,
"messages": messages,
"temperature": temperature,
"response_format": {"type": "json_object"} if json_mode else None,
"response_format": response_format or ({"type": "json_object"} if json_mode else None),
"top_p": top_p,
"stream": stream,
}
# Add api_key and api_base if provided
@ -88,12 +79,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
api_call_params["tool_choice"] = tool_choice
if "claude" in model_name.lower() or "anthropic" in model_name.lower():
# if messages[0]["role"] != "user":
# api_call_params["messages"] = [{"role": "user", "content": "."}] + messages
api_call_params["extra_headers"] = {
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
}
# # "anthropic-beta": "prompt-caching-2024-07-31"
# Log the API request
logger.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")
@ -101,20 +89,49 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
if agentops_session:
response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
else:
if stream:
response = await litellm.acompletion(**api_call_params, stream=True)
else:
response = await litellm.acompletion(**api_call_params)
response = await litellm.acompletion(**api_call_params)
# Log the API response
logger.info(f"Received API response: {response}")
return response
async for result in attempt_api_call(api_call):
yield result
return await attempt_api_call(api_call)
# Sample Usage
if __name__ == "__main__":
import asyncio
async def test_llm_api_call(stream=True):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Complex essay on economics"}
]
model_name = "gpt-4o"
pass
response = await make_llm_api_call(messages, model_name, stream=stream)
if stream:
print("Streaming response:")
async for chunk in response:
if isinstance(chunk, dict) and 'choices' in chunk:
content = chunk['choices'][0]['delta'].get('content', '')
print(content, end='', flush=True)
else:
# For non-dict responses (like ModelResponse objects)
content = chunk.choices[0].delta.content
if content:
print(content, end='', flush=True)
print("\nStream completed.")
else:
print("Non-streaming response:")
if isinstance(response, dict) and 'choices' in response:
print(response['choices'][0]['message']['content'])
else:
# For non-dict responses (like ModelResponse objects)
print(response.choices[0].message.content)
# Example usage:
# asyncio.run(test_llm_api_call(stream=True)) # For streaming
# asyncio.run(test_llm_api_call(stream=False)) # For non-streaming
asyncio.run(test_llm_api_call())

View File

@ -5,19 +5,18 @@ from typing import List, Dict, Any, Optional, Callable, AsyncGenerator, Union
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import Database, Thread, ThreadRun
from core.tools.tool import Tool, ToolResult
from core.tool import Tool, ToolResult
from core.llm import make_llm_api_call
# from core.working_memory_manager import WorkingMemory
from datetime import datetime
from core.tools.tool_registry import ToolRegistry
from core.tool_registry import ToolRegistry
import re
from core.agent_manager import AgentManager
import uuid
class ThreadManager:
def __init__(self, db: Database):
self.db = db
self.tool_registry = ToolRegistry()
self.agent_manager = AgentManager(db)
async def create_thread(self) -> int:
async with self.db.get_async_session() as session:
@ -197,84 +196,176 @@ class ThreadManager:
async def run_thread(
self,
thread_id: int,
agent_id: Optional[int] = None,
system_message: Optional[Dict[str, Any]] = None,
model_name: Optional[str] = None,
temperature: Optional[float] = None,
thread_id: str,
system_message: Dict[str, Any],
model_name: str,
temperature: float = 0.5,
max_tokens: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
additional_instructions: Optional[str] = None,
tools: Optional[List[str]] = None,
additional_system_message: Optional[str] = None,
hide_tool_msgs: bool = False,
execute_tools_async: bool = True,
use_tool_parser: bool = False,
stream: bool = False
) -> AsyncGenerator[Union[Dict[str, Any], str], None]:
if agent_id is not None:
agent = await self.agent_manager.get_agent(agent_id)
if not agent:
raise ValueError(f"Agent with id {agent_id} not found")
system_message = {"role": "system", "content": agent.system_prompt}
model_name = agent.model
temperature = agent.temperature
tools = [self.tool_registry.get_tool(tool).schema()[0] for tool in agent.selected_tools] if agent.selected_tools else None
elif system_message is None or model_name is None:
raise ValueError("Either agent_id or system_message and model_name must be provided")
top_p: Optional[float] = None,
tool_choice: str = "auto",
response_format: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
run_id = str(uuid.uuid4())
# Fetch full tool objects based on the provided tool names
full_tools = None
if tools:
full_tools = [self.tool_registry.get_tool(tool_name)['schema'] for tool_name in tools if self.tool_registry.get_tool(tool_name)]
thread_run = ThreadRun(
id=run_id,
thread_id=thread_id,
status="queued",
model=model_name,
system_message=json.dumps(system_message),
tools=json.dumps(full_tools) if full_tools else None,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
tool_choice=tool_choice,
execute_tools_async=execute_tools_async,
response_format=json.dumps(response_format) if response_format else None
)
if await self.should_stop(thread_id):
yield {"status": "stopped", "message": "Session cancelled"}
return
async with self.db.get_async_session() as session:
session.add(thread_run)
await session.commit()
if use_tool_parser:
hide_tool_msgs = True
await self.cleanup_incomplete_tool_calls(thread_id)
# Prepare messages
messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs)
prepared_messages = [system_message] + messages
if additional_instructions:
additional_instruction_message = {
"role": "user",
"content": additional_instructions
}
prepared_messages.append(additional_instruction_message)
try:
response_stream = make_llm_api_call(
thread_run.status = "in_progress"
thread_run.started_at = int(datetime.utcnow().timestamp())
await self.update_thread_run(thread_run)
if await self.should_stop(thread_id):
thread_run.status = "cancelled"
thread_run.cancelled_at = int(datetime.utcnow().timestamp())
await self.update_thread_run(thread_run)
return {"status": "stopped", "message": "Session cancelled"}
if use_tool_parser:
hide_tool_msgs = True
await self.cleanup_incomplete_tool_calls(thread_id)
# Prepare messages
messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs)
prepared_messages = [system_message] + messages
if additional_system_message:
additional_instruction_message = {
"role": "user",
"content": additional_system_message
}
prepared_messages.append(additional_instruction_message)
response = await make_llm_api_call(
prepared_messages,
model_name,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
stream=stream
tools=full_tools,
tool_choice=tool_choice,
stream=False,
top_p=top_p,
response_format=response_format
)
async for partial_response in response_stream:
if stream:
yield partial_response
else:
response = partial_response
usage = response.usage if hasattr(response, 'usage') else None
usage_dict = self.serialize_usage(usage) if usage else None
thread_run.usage = usage_dict
if not stream:
if tools is None or use_tool_parser:
await self.handle_response_without_tools(thread_id, response, use_tool_parser)
else:
await self.handle_response_with_tools(thread_id, response, execute_tools_async)
# Add the assistant's message to the thread
assistant_message = {
"role": "assistant",
"content": response.choices[0].message['content']
}
if 'tool_calls' in response.choices[0].message:
assistant_message['tool_calls'] = response.choices[0].message['tool_calls']
await self.add_message(thread_id, assistant_message)
if await self.should_stop(thread_id):
yield {"status": "stopped", "message": "Session cancelled"}
else:
await self.save_thread_run(thread_id)
yield response
if tools is None or use_tool_parser:
await self.handle_response_without_tools(thread_id, response, use_tool_parser)
else:
await self.handle_response_with_tools(thread_id, response, execute_tools_async)
thread_run.status = "completed"
thread_run.completed_at = int(datetime.utcnow().timestamp())
await self.update_thread_run(thread_run)
return {
"id": thread_run.id,
"choices": [self.serialize_choice(choice) for choice in response.choices],
"usage": usage_dict,
"model": model_name,
"object": "chat.completion",
"created": int(datetime.utcnow().timestamp())
}
except Exception as e:
error_message = f"Error in API call: {str(e)}\n\nFull error: {repr(e)}"
logging.error(error_message)
await self.update_thread_run_with_error(thread_id, error_message)
yield {"status": "error", "message": error_message}
thread_run.status = "failed"
thread_run.failed_at = int(datetime.utcnow().timestamp())
thread_run.last_error = error_message
await self.update_thread_run(thread_run)
return {"status": "error", "message": error_message}
def serialize_usage(self, usage):
return {
"completion_tokens": usage.completion_tokens,
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens,
"completion_tokens_details": self.serialize_completion_tokens_details(usage.completion_tokens_details),
"prompt_tokens_details": self.serialize_prompt_tokens_details(usage.prompt_tokens_details)
}
def serialize_completion_tokens_details(self, details):
return {
"audio_tokens": details.audio_tokens,
"reasoning_tokens": details.reasoning_tokens
}
def serialize_prompt_tokens_details(self, details):
return {
"audio_tokens": details.audio_tokens,
"cached_tokens": details.cached_tokens
}
def serialize_choice(self, choice):
return {
"finish_reason": choice.finish_reason,
"index": choice.index,
"message": self.serialize_message(choice.message)
}
def serialize_message(self, message):
return {
"content": message.content,
"role": message.role,
"tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None
}
def serialize_tool_call(self, tool_call):
return {
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
async def update_thread_run(self, thread_run: ThreadRun):
async with self.db.get_async_session() as session:
session.add(thread_run)
await session.commit()
await session.refresh(thread_run)
async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
response_content = response.choices[0].message['content']
@ -282,8 +373,8 @@ class ThreadManager:
if use_tool_parser:
await self.handle_tool_parser_response(thread_id, response_content)
else:
logging.info("Adding assistant message to thread.")
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
# The message has already been added in the run_thread method, so we don't need to add it again here
pass
async def handle_tool_parser_response(self, thread_id: int, response_content: str):
tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content)
@ -325,9 +416,8 @@ class ThreadManager:
response_message = response.choices[0].message
tool_calls = response_message.get('tool_calls', [])
assistant_message = self.create_assistant_message_with_tools(response_message)
await self.add_message(thread_id, assistant_message)
# The assistant message has already been added in the run_thread method
available_functions = self.get_available_functions()
if await self.should_stop(thread_id):
@ -348,9 +438,7 @@ class ThreadManager:
except AttributeError as e:
logging.error(f"AttributeError: {e}")
content = response_message.get('content', '')
if content:
await self.add_message(thread_id, {"role": "assistant", "content": content})
# No need to add the message here as it's already been added in the run_thread method
def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
message = {
@ -452,7 +540,7 @@ class ThreadManager:
result = await session.execute(stmt)
return result.scalar_one_or_none() is not None
async def save_thread_run(self, thread_id: int):
async def save_thread_run(self, thread_id: str):
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if not thread:
@ -461,14 +549,26 @@ class ThreadManager:
messages = json.loads(thread.messages)
creation_date = datetime.now().isoformat()
new_thread_run = ThreadRun(
thread_id=thread_id,
messages=json.dumps(messages),
creation_date=creation_date,
status='completed'
)
session.add(new_thread_run)
await session.commit()
# Get the latest ThreadRun for this thread
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
result = await session.execute(stmt)
latest_thread_run = result.scalar_one_or_none()
if latest_thread_run:
# Update the existing ThreadRun
latest_thread_run.messages = json.dumps(messages)
latest_thread_run.last_updated_date = creation_date
await session.commit()
else:
# Create a new ThreadRun if none exists
new_thread_run = ThreadRun(
thread_id=thread_id,
messages=json.dumps(messages),
creation_date=creation_date,
status='completed'
)
session.add(new_thread_run)
await session.commit()
async def get_thread(self, thread_id: int) -> Optional[Thread]:
async with self.db.get_async_session() as session:
@ -489,11 +589,92 @@ class ThreadManager:
result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
return result.scalars().all()
async def get_latest_thread_run(self, thread_id: int):
async def get_latest_thread_run(self, thread_id: str):
async with self.db.get_async_session() as session:
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.run_id.desc()).limit(1)
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
result = await session.execute(stmt)
return result.scalar_one_or_none()
latest_run = result.scalar_one_or_none()
if latest_run:
return {
"id": latest_run.id,
"status": latest_run.status,
"error_message": latest_run.last_error,
"created_at": latest_run.created_at,
"started_at": latest_run.started_at,
"completed_at": latest_run.completed_at,
"cancelled_at": latest_run.cancelled_at,
"failed_at": latest_run.failed_at,
"model": latest_run.model,
"usage": latest_run.usage
}
return None
async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
async with self.db.get_async_session() as session:
run = await session.get(ThreadRun, run_id)
if run and run.thread_id == thread_id:
return {
"id": run.id,
"thread_id": run.thread_id,
"status": run.status,
"created_at": run.created_at,
"started_at": run.started_at,
"completed_at": run.completed_at,
"cancelled_at": run.cancelled_at,
"failed_at": run.failed_at,
"model": run.model,
"system_message": json.loads(run.system_message) if run.system_message else None,
"tools": json.loads(run.tools) if run.tools else None,
"usage": run.usage,
"temperature": run.temperature,
"top_p": run.top_p,
"max_tokens": run.max_tokens,
"tool_choice": run.tool_choice,
"execute_tools_async": run.execute_tools_async,
"response_format": json.loads(run.response_format) if run.response_format else None,
"last_error": run.last_error
}
return None
async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
async with self.db.get_async_session() as session:
run = await session.get(ThreadRun, run_id)
if run and run.thread_id == thread_id and run.status == "in_progress":
run.status = "cancelled"
run.cancelled_at = int(datetime.utcnow().timestamp())
await session.commit()
return await self.get_run(thread_id, run_id)
return None
async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
async with self.db.get_async_session() as session:
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit)
result = await session.execute(stmt)
runs = result.scalars().all()
return [
{
"id": run.id,
"thread_id": run.thread_id,
"status": run.status,
"created_at": run.created_at,
"started_at": run.started_at,
"completed_at": run.completed_at,
"cancelled_at": run.cancelled_at,
"failed_at": run.failed_at,
"model": run.model,
"system_message": json.loads(run.system_message) if run.system_message else None,
"tools": json.loads(run.tools) if run.tools else None,
"usage": run.usage,
"temperature": run.temperature,
"top_p": run.top_p,
"max_tokens": run.max_tokens,
"tool_choice": run.tool_choice,
"execute_tools_async": run.execute_tools_async,
"response_format": json.loads(run.response_format) if run.response_format else None,
"last_error": run.last_error
}
for run in runs
]
if __name__ == "__main__":
import asyncio
@ -519,6 +700,4 @@ if __name__ == "__main__":
print(f"Response: {response}")
asyncio.run(main())
asyncio.run(main())

View File

@ -1,5 +1,5 @@
from typing import Dict, Type, Any
from core.tools.tool import Tool
from typing import Dict, Type, Any, Optional
from core.tool import Tool
from core.config import settings
import importlib.util
import os
@ -39,7 +39,7 @@ class ToolRegistry:
print(f"Registered tools: {list(self.tools.keys())}") # Debug print
def get_tool(self, tool_name: str) -> Dict[str, Any]:
def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
return self.tools.get(tool_name)
def get_all_tools(self) -> Dict[str, Dict[str, Any]]:

View File

@ -1,97 +0,0 @@
import streamlit as st
import requests
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
def display_agent_management():
st.header("Agent Management")
col1, col2 = st.columns([1, 1])
with col1:
display_create_agent_form()
with col2:
display_existing_agents()
def display_create_agent_form():
st.subheader("Create New Agent")
with st.form("create_agent_form"):
new_agent_name = st.text_input("Agent Name")
new_agent_model = st.selectbox("Model", AI_MODELS)
new_agent_system_prompt = st.text_area("System Prompt", value=STANDARD_SYSTEM_MESSAGE)
new_agent_temperature = st.slider("Temperature", 0.0, 1.0, 0.5)
tool_options = list(st.session_state.tools.keys())
new_agent_tools = st.multiselect("Tools", tool_options)
submitted = st.form_submit_button("Create Agent")
if submitted:
create_agent(new_agent_name, new_agent_model, new_agent_system_prompt, new_agent_temperature, new_agent_tools)
def create_agent(name, model, system_prompt, temperature, tools):
response = requests.post(f"{API_BASE_URL}/agents/", json={
"name": name,
"model": model,
"system_prompt": system_prompt,
"temperature": temperature,
"selected_tools": tools
})
if response.status_code == 200:
st.success("Agent created successfully!")
st.session_state.fetch_agents()
else:
st.error("Failed to create agent.")
def display_existing_agents():
st.subheader("Existing Agents")
for agent in st.session_state.agents:
with st.expander(f"Agent: {agent['name']}"):
display_agent_details(agent)
def display_agent_details(agent):
st.write(f"Model: {agent['model']}")
st.write(f"Temperature: {agent['temperature']}")
st.write(f"Tools: {', '.join(agent['selected_tools'])}")
if st.button(f"Edit Agent {agent['id']}"):
st.session_state.editing_agent = agent['id']
if st.button(f"Delete Agent {agent['id']}"):
delete_agent(agent['id'])
if st.session_state.get('editing_agent') == agent['id']:
edit_agent_form(agent)
def edit_agent_form(agent):
with st.form(f"edit_agent_form_{agent['id']}"):
updated_name = st.text_input("Agent Name", value=agent['name'])
updated_model = st.selectbox("Model", AI_MODELS, index=AI_MODELS.index(agent['model']))
updated_system_prompt = st.text_area("System Prompt", value=agent['system_prompt'])
updated_temperature = st.slider("Temperature", 0.0, 1.0, value=agent['temperature'])
tool_options = list(st.session_state.tools.keys())
updated_tools = st.multiselect("Tools", options=tool_options, default=agent['selected_tools'])
if st.form_submit_button("Update Agent"):
update_agent(agent['id'], updated_name, updated_model, updated_system_prompt, updated_temperature, updated_tools)
st.session_state.editing_agent = None
def update_agent(agent_id, name, model, system_prompt, temperature, tools):
response = requests.put(f"{API_BASE_URL}/agents/{agent_id}", json={
"name": name,
"model": model,
"system_prompt": system_prompt,
"temperature": temperature,
"selected_tools": tools
})
if response.status_code == 200:
st.success("Agent updated successfully!")
st.session_state.fetch_agents()
else:
st.error("Failed to update agent.")
def delete_agent(agent_id):
response = requests.delete(f"{API_BASE_URL}/agents/{agent_id}")
if response.status_code == 200:
st.success("Agent deleted successfully!")
st.session_state.fetch_agents()
else:
st.error("Failed to delete agent.")

View File

@ -1,34 +1,45 @@
import streamlit as st
from core.ui.thread_management import display_thread_management
from core.ui.message_display import display_messages
from core.ui.thread_runner import display_thread_runner
from core.ui.agent_management import display_agent_management
from core.ui.message_display import display_messages_and_runner
from core.ui.thread_runner import fetch_thread_runs, display_runs
from core.ui.tool_display import display_tools
from core.ui.utils import initialize_session_state, fetch_data
from core.ui.utils import initialize_session_state, fetch_data, API_BASE_URL
def main():
initialize_session_state()
fetch_data()
st.set_page_config(page_title="AI Assistant Management System", layout="wide")
st.sidebar.title("Navigation")
mode = st.sidebar.radio("Select Mode", ["Agent Management", "Thread Management", "Tools"])
mode = st.sidebar.radio("Select Mode", ["Thread Management", "Tools"])
st.title("AI Assistant Management System")
if mode == "Agent Management":
display_agent_management()
elif mode == "Tools":
if mode == "Tools":
display_tools()
else: # Thread Management
display_thread_management_content()
def display_thread_management_content():
st.header("Thread Management")
display_thread_management()
col1, col2 = st.columns([1, 3])
with col1:
display_thread_management()
if st.session_state.selected_thread:
display_thread_runner(st.session_state.selected_thread)
with col2:
if st.session_state.selected_thread:
display_messages_and_runner(st.session_state.selected_thread)
if st.session_state.selected_thread:
display_messages(st.session_state.selected_thread)
display_thread_runner(st.session_state.selected_thread)
def display_thread_runner(thread_id):
st.subheader("Thread Runs")
limit = st.number_input("Number of runs to retrieve", min_value=1, max_value=100, value=20)
if st.button("Fetch Runs"):
runs = fetch_thread_runs(thread_id, limit)
display_runs(runs)
if __name__ == "__main__":
main()

View File

@ -1,16 +1,21 @@
import streamlit as st
import requests
import json
from core.ui.utils import API_BASE_URL
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
from core.ui.thread_runner import prepare_run_thread_data, run_thread, display_response_content
def display_messages(thread_id):
st.subheader(f"🧵 Thread ID: {thread_id}")
# st.write("### 📝 Messages")
def display_messages_and_runner(thread_id):
st.subheader(f"Messages for Thread: {thread_id}")
messages = fetch_messages(thread_id)
display_message_json(messages)
display_message_list(messages)
display_add_message_form(thread_id)
col1, col2 = st.columns(2)
with col1:
display_add_message_form(thread_id)
with col2:
display_run_thread_form(thread_id)
def fetch_messages(thread_id):
messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/")
@ -20,28 +25,37 @@ def fetch_messages(thread_id):
st.error("Failed to fetch messages.")
return []
def display_message_json(messages):
if st.button("Show/Hide JSON"):
st.session_state.show_json = not st.session_state.get('show_json', False)
if st.session_state.get('show_json', False):
json_str = json.dumps(messages, indent=2)
st.code(json_str, language="json")
def display_message_list(messages):
for msg in messages:
with st.chat_message(msg['role']):
st.write(msg['content'])
def display_add_message_form(thread_id):
st.write("### Add a New Message")
st.write("### Add a New Message")
with st.form(key="add_message_form"):
role = st.selectbox("🔹 Role", ["user", "assistant"], key="add_role")
content = st.text_area("📝 Content", key="add_content")
submitted = st.form_submit_button(" Add Message")
role = st.selectbox("Role", ["user", "assistant"], key="add_role")
content = st.text_area("Content", key="add_content")
submitted = st.form_submit_button("Add Message")
if submitted:
add_message(thread_id, role, content)
def display_run_thread_form(thread_id):
st.write("### Run Thread")
with st.form(key="run_thread_form"):
model_name = st.selectbox("Model", AI_MODELS, key="model_name")
temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
additional_system_message = st.text_area("Additional System Message", key="additional_system_message", height=100)
tool_options = list(st.session_state.tools.keys())
selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
submitted = st.form_submit_button("Run Thread")
if submitted:
run_thread_data = prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools)
run_thread(thread_id, run_thread_data)
def add_message(thread_id, role, content):
message_data = {"role": role, "content": content}
add_msg_response = requests.post(
@ -49,6 +63,7 @@ def add_message(thread_id, role, content):
json=message_data
)
if add_msg_response.status_code == 200:
st.success("Message added successfully.")
st.rerun()
else:
st.error("Failed to add message.")

View File

@ -1,22 +1,22 @@
import streamlit as st
import requests
from core.ui.utils import API_BASE_URL
from datetime import datetime
def display_thread_management():
col1, col2 = st.columns([1, 2])
st.subheader("Thread Management")
with col1:
if st.button(" Create New Thread"):
create_new_thread()
if st.button(" Create New Thread", key="create_thread_button"):
create_new_thread()
with col2:
display_thread_selector()
display_thread_selector()
def create_new_thread():
response = requests.post(f"{API_BASE_URL}/threads/")
if response.status_code == 200:
thread_id = response.json()['thread_id']
st.session_state.selected_thread = thread_id
st.success(f"New thread created with ID: {thread_id}")
st.rerun()
else:
st.error("Failed to create a new thread.")
@ -25,17 +25,36 @@ def display_thread_selector():
threads_response = requests.get(f"{API_BASE_URL}/threads/")
if threads_response.status_code == 200:
threads = threads_response.json()
thread_options = [str(thread['thread_id']) for thread in threads]
# Sort threads by creation date if available, otherwise by thread_id
def sort_key(thread):
if 'creation_date' in thread:
try:
return datetime.strptime(thread['creation_date'], "%Y-%m-%d %H:%M:%S")
except ValueError:
st.warning(f"Invalid date format for thread {thread['thread_id']}")
return thread['thread_id']
sorted_threads = sorted(threads, key=sort_key, reverse=True)
thread_options = []
for thread in sorted_threads:
if 'creation_date' in thread:
thread_options.append(f"{thread['thread_id']} - Created: {thread['creation_date']}")
else:
thread_options.append(thread['thread_id'])
if st.session_state.selected_thread is None and threads:
st.session_state.selected_thread = str(threads[0]['thread_id'])
if st.session_state.selected_thread is None and sorted_threads:
st.session_state.selected_thread = sorted_threads[0]['thread_id']
selected_thread = st.selectbox(
"🔍 Select Thread",
thread_options,
key="thread_select",
index=thread_options.index(str(st.session_state.selected_thread)) if st.session_state.selected_thread else 0
index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0)
)
if selected_thread:
st.session_state.selected_thread = int(selected_thread)
st.session_state.selected_thread = selected_thread.split(' - ')[0] if ' - ' in selected_thread else selected_thread
else:
st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}")

View File

@ -1,66 +1,16 @@
import streamlit as st
import requests
import json
from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
from core.ui.utils import API_BASE_URL
from datetime import datetime
def display_thread_runner(thread_id):
st.write("## ⚙️ Run Thread")
manual_config = display_manual_setup_tab()
# Common settings
additional_instructions = st.text_area("Additional Instructions", key="additional_instructions", height=100)
stream = st.checkbox("📡 Stream Responses", key="stream_responses")
# Prepare the run thread data
run_thread_data = prepare_run_thread_data(manual_config, additional_instructions, stream)
# Display the preview of the request payload in an expander
with st.expander("📤 Preview Request Payload", expanded=False):
st.json(run_thread_data)
# Center the run button and make it more prominent
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
if st.button("▶️ Run Thread", key="run_thread_button", use_container_width=True):
run_thread(thread_id, run_thread_data)
display_thread_run_status(thread_id)
def display_manual_setup_tab():
model_name = st.selectbox("Model", AI_MODELS, key="model_name")
col1, col2 = st.columns(2)
with col1:
temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
with col2:
max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
tool_options = list(st.session_state.tools.keys())
selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
manual_config = {
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools):
return {
"system_message": {"role": "system", "content": system_message},
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"system_message": system_message,
"selected_tools": selected_tools
}
return manual_config
def prepare_run_thread_data(manual_config, additional_instructions, stream):
tools = [st.session_state.tools[tool]['schema'] for tool in manual_config['selected_tools'] if tool in st.session_state.tools]
return {
"system_message": {"role": "system", "content": manual_config['system_message']},
"model_name": manual_config['model_name'],
"temperature": manual_config['temperature'],
"max_tokens": manual_config['max_tokens'],
"tools": tools,
"additional_instructions": additional_instructions,
"stream": stream
"tools": selected_tools,
"additional_system_message": additional_system_message
}
def run_thread(thread_id, run_thread_data):
@ -73,49 +23,77 @@ def run_thread(thread_id, run_thread_data):
response_data = run_thread_response.json()
st.success("Thread run completed successfully!")
# Display the return payload in an expander
with st.expander("📥 Return Payload", expanded=False):
st.json(response_data)
if 'id' in response_data:
st.session_state.latest_run_id = response_data['id']
# Display the actual response content
st.write("### 📬 Response Content")
st.subheader("Response Content")
display_response_content(response_data)
st.rerun()
else:
st.error("Failed to run thread.")
with st.expander("❌ Error Response", expanded=True):
st.json(run_thread_response.json())
st.error(f"Failed to run thread. Status code: {run_thread_response.status_code}")
st.text("Response content:")
st.text(run_thread_response.text)
def display_response_content(response_data):
if isinstance(response_data, dict) and 'response' in response_data:
for item in response_data['response']:
if isinstance(item, dict):
if 'content' in item:
st.markdown(item['content'])
elif 'tool_calls' in item:
st.write("**Tool Calls:**")
for tool_call in item['tool_calls']:
st.write(f"- Function: `{tool_call['function']['name']}`")
st.code(tool_call['function']['arguments'], language="json")
elif isinstance(item, str):
st.markdown(item)
if isinstance(response_data, dict) and 'choices' in response_data:
message = response_data['choices'][0]['message']
st.write(f"**Role:** {message['role']}")
st.write(f"**Content:** {message['content']}")
if 'tool_calls' in message:
st.write("**Tool Calls:**")
for tool_call in message['tool_calls']:
st.write(f"- Function: `{tool_call['function']['name']}`")
st.code(tool_call['function']['arguments'], language="json")
else:
st.json(response_data)
def display_thread_run_status(thread_id):
status_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/run/status/")
if status_response.status_code == 200:
status_data = status_response.json()
st.write("### ⚙️ Thread Run Status")
status = status_data.get('status')
if status == 'completed':
st.success(f"**Status:** {status}")
elif status == 'error':
st.error(f"**Status:** {status}")
with st.expander("Error Details", expanded=True):
st.code(status_data.get('error_message'), language="")
else:
st.info(f"**Status:** {status}")
def fetch_thread_runs(thread_id, limit):
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs?limit={limit}")
if response.status_code == 200:
return response.json()
else:
st.warning("Could not retrieve thread run status.")
st.error("Failed to retrieve runs.")
return []
def format_timestamp(timestamp):
if timestamp:
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
return 'N/A'
def display_runs(runs):
for run in runs:
with st.expander(f"Run {run['id']} - Status: {run['status']}", expanded=False):
col1, col2 = st.columns(2)
with col1:
st.write(f"**Created At:** {format_timestamp(run['created_at'])}")
st.write(f"**Started At:** {format_timestamp(run['started_at'])}")
st.write(f"**Completed At:** {format_timestamp(run['completed_at'])}")
st.write(f"**Cancelled At:** {format_timestamp(run['cancelled_at'])}")
st.write(f"**Failed At:** {format_timestamp(run['failed_at'])}")
with col2:
st.write(f"**Model:** {run['model']}")
st.write(f"**Temperature:** {run['temperature']}")
st.write(f"**Top P:** {run['top_p']}")
st.write(f"**Max Tokens:** {run['max_tokens']}")
st.write(f"**Tool Choice:** {run['tool_choice']}")
st.write(f"**Execute Tools Async:** {run['execute_tools_async']}")
st.write("**System Message:**")
st.json(run['system_message'])
if run['tools']:
st.write("**Tools:**")
st.json(run['tools'])
if run['usage']:
st.write("**Usage:**")
st.json(run['usage'])
if run['response_format']:
st.write("**Response Format:**")
st.json(run['response_format'])
if run['last_error']:
st.error("**Last Error:**")
st.code(run['last_error'])

View File

@ -4,13 +4,6 @@ from core.constants import AI_MODELS, STANDARD_SYSTEM_MESSAGE
API_BASE_URL = "http://localhost:8000"
def fetch_agents():
response = requests.get(f"{API_BASE_URL}/agents/")
if response.status_code == 200:
st.session_state.agents = response.json()
else:
st.error("Failed to fetch agents.")
def fetch_tools():
response = requests.get(f"{API_BASE_URL}/tools/")
if response.status_code == 200:
@ -21,15 +14,10 @@ def fetch_tools():
def initialize_session_state():
if 'selected_thread' not in st.session_state:
st.session_state.selected_thread = None
if 'agents' not in st.session_state:
st.session_state.agents = []
if 'tools' not in st.session_state:
st.session_state.tools = []
if 'fetch_agents' not in st.session_state:
st.session_state.fetch_agents = fetch_agents
if 'fetch_tools' not in st.session_state:
st.session_state.fetch_tools = fetch_tools
def fetch_data():
fetch_agents()
fetch_tools()

BIN
main.db

Binary file not shown.

0
main.py Normal file
View File

87
poetry.lock generated
View File

@ -417,6 +417,26 @@ files = [
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]
[[package]]
name = "fastapi"
version = "0.115.0"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.8"
files = [
{file = "fastapi-0.115.0-py3-none-any.whl", hash = "sha256:17ea427674467486e997206a5ab25760f6b09e069f099b96f5b55a32fb6f1631"},
{file = "fastapi-0.115.0.tar.gz", hash = "sha256:f93b4ca3529a8ebc6fc3fcf710e5efa8de3df9b41570958abf1d97d843138004"},
]
[package.dependencies]
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
starlette = ">=0.37.2,<0.39.0"
typing-extensions = ">=4.8.0"
[package.extras]
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"]
[[package]]
name = "filelock"
version = "3.16.1"
@ -2284,6 +2304,53 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sse-starlette"
version = "2.1.3"
description = "SSE plugin for Starlette"
optional = false
python-versions = ">=3.8"
files = [
{file = "sse_starlette-2.1.3-py3-none-any.whl", hash = "sha256:8ec846438b4665b9e8c560fcdea6bc8081a3abf7942faa95e5a744999d219772"},
{file = "sse_starlette-2.1.3.tar.gz", hash = "sha256:9cd27eb35319e1414e3d2558ee7414487f9529ce3b3cf9b21434fd110e017169"},
]
[package.dependencies]
anyio = "*"
starlette = "*"
uvicorn = "*"
[package.extras]
examples = ["fastapi"]
[[package]]
name = "sseclient-py"
version = "1.7.2"
description = "SSE client for Python"
optional = false
python-versions = "*"
files = [
{file = "sseclient-py-1.7.2.tar.gz", hash = "sha256:ba3197d314766eccb72a1dda80b5fa14a0fbba07d796a287654c07edde88fe0f"},
{file = "sseclient_py-1.7.2-py2.py3-none-any.whl", hash = "sha256:a758653b13b78df42cdb696740635a26cb72ad433b75efb68dbbb163d099b6a9"},
]
[[package]]
name = "starlette"
version = "0.38.6"
description = "The little ASGI library that shines."
optional = false
python-versions = ">=3.8"
files = [
{file = "starlette-0.38.6-py3-none-any.whl", hash = "sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05"},
{file = "starlette-0.38.6.tar.gz", hash = "sha256:863a1588f5574e70a821dadefb41e4881ea451a47a3cd1b4df359d4ffefe5ead"},
]
[package.dependencies]
anyio = ">=3.4.0,<5"
[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
[[package]]
name = "streamlit"
version = "1.39.0"
@ -2615,6 +2682,24 @@ h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "uvicorn"
version = "0.31.0"
description = "The lightning-fast ASGI server."
optional = false
python-versions = ">=3.8"
files = [
{file = "uvicorn-0.31.0-py3-none-any.whl", hash = "sha256:cac7be4dd4d891c363cd942160a7b02e69150dcbc7a36be04d5f4af4b17c8ced"},
{file = "uvicorn-0.31.0.tar.gz", hash = "sha256:13bc21373d103859f68fe739608e2eb054a816dea79189bc3ca08ea89a275906"},
]
[package.dependencies]
click = ">=7.0"
h11 = ">=0.8"
[package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
[[package]]
name = "watchdog"
version = "5.0.3"
@ -2784,4 +2869,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "817f3d510c67d6a59033e26f11a247490013e3e63c6184844d6b261f43352cfd"
content-hash = "11cd10192810fd94515b1f6ad5d8ce27b5d665d1d4e65fff9b05b4cd80c9d377"

View File

@ -21,6 +21,9 @@ litellm = "^1.44.4"
pytest = "^8.3.2"
pytest-asyncio = "^0.24.0"
agentops = "^0.3.10"
sseclient-py = "1.7.2"
fastapi = "^0.115.0"
sse-starlette = "^2.1.3"
[tool.poetry.group.dev.dependencies]

View File

@ -1,7 +1,7 @@
import os
import asyncio
from typing import Dict, Any
from core.tools.tool import Tool, ToolResult
from core.tool import Tool, ToolResult
from core.config import settings
class FilesTool(Tool):

View File

@ -1,5 +1,5 @@
from typing import Dict, Any
from core.tools.tool import Tool, ToolResult
from core.tool import Tool, ToolResult
class ExampleTool(Tool):
description = "An example tool for demonstration purposes."

1
workspace/hello.txt Normal file
View File

@ -0,0 +1 @@
Hello there! This is a sample file created just for you.

1
workspace/info.txt Normal file
View File

@ -0,0 +1 @@
This file contains random information about this assistant. I'm here to help you with a wide range of tasks, from answering questions to managing files.

View File

@ -1 +0,0 @@
This is some random content for the file.

View File

@ -1 +0,0 @@
random contents

4
workspace/tips.txt Normal file
View File

@ -0,0 +1,4 @@
Here are some random tips:
1. Stay curious and keep learning.
2. Organize your workspace for better productivity.
3. Take breaks to maintain focus and energy.