This commit is contained in:
marko-kraemer 2024-10-23 03:28:12 +02:00
parent 21e06f3162
commit 3f69ea9cc4
15 changed files with 273 additions and 1355 deletions

View File

@ -1,10 +1,10 @@
from .config import settings
from .db import Database, Thread, ThreadRun
from .db import Database, Thread
from .llm import make_llm_api_call
from .thread_manager import ThreadManager
# from .working_memory_manager import WorkingMemory
__all__ = [
'settings', 'Database', 'Thread', 'ThreadRun',
'settings', 'Database', 'Thread',
'make_llm_api_call', 'ThreadManager'
] #'WorkingMemory'

View File

@ -1,182 +0,0 @@
from fastapi import FastAPI, HTTPException, Query, Path
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
import asyncio
import json
from agentpress.db import Database
from agentpress.thread_manager import ThreadManager
from agentpress.tool_registry import ToolRegistry
from agentpress.config import Settings
app = FastAPI(
title="Thread Manager API",
description="API for managing and running threads with LLM integration",
version="1.0.0",
)
db = Database()
manager = ThreadManager(db)
tool_registry = ToolRegistry()
class Message(BaseModel):
role: str = Field(..., description="The role of the message sender (e.g., 'user', 'assistant')")
content: str = Field(..., description="The content of the message")
class RunThreadRequest(BaseModel):
system_message: Dict[str, Any] = Field(..., description="The system message to be used for the thread run")
model_name: str = Field(..., description="The name of the LLM model to be used")
temperature: float = Field(0.5, description="The sampling temperature for the LLM")
max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate")
tools: Optional[List[str]] = Field(None, description="The list of tools to be used in the thread run")
tool_choice: str = Field("auto", description="Controls which tool is called by the model")
additional_system_message: Optional[str] = Field(None, description="Additional system message to be appended to the existing system message")
additional_message: Optional[Dict[str, Any]] = Field(None, description="Additional message to be appended at the end of the conversation")
hide_tool_msgs: bool = Field(False, description="Whether to hide tool messages in the conversation history")
execute_tools_async: bool = Field(True, description="Whether to execute tools asynchronously")
use_tool_parser: bool = Field(False, description="Whether to use the tool parser for handling tool calls")
top_p: Optional[float] = Field(None, description="The nucleus sampling value")
response_format: Optional[Dict[str, Any]] = Field(None, description="Specifies the format that the model must output")
autonomous_iterations_amount: Optional[int] = Field(None, description="The number of autonomous iterations to perform")
continue_instructions: Optional[str] = Field(None, description="Instructions for continuing the conversation in subsequent iterations")
initializer: Optional[str] = Field(None, description="Name of the initializer function")
pre_iteration: Optional[str] = Field(None, description="Name of the pre-iteration function")
after_iteration: Optional[str] = Field(None, description="Name of the after-iteration function")
finalizer: Optional[str] = Field(None, description="Name of the finalizer function")
@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread")
async def create_thread():
"""
Create a new thread and return its ID.
"""
thread_id = await manager.create_thread()
return {"thread_id": thread_id}
@app.get("/threads/", response_model=List[Dict[str, Any]], summary="Get all threads")
async def get_threads():
"""
Retrieve a list of all threads.
"""
threads = await manager.get_threads()
return [{"thread_id": thread.thread_id, "created_at": thread.created_at} for thread in threads]
@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread")
async def add_message(thread_id: str, message: Message):
"""
Add a new message to the specified thread.
"""
await manager.add_message(thread_id, message.dict())
return {"status": "success"}
@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread")
async def list_messages(thread_id: str):
"""
Retrieve all messages from the specified thread.
"""
messages = await manager.list_messages(thread_id)
return messages
@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread")
async def run_thread(thread_id: str, request: RunThreadRequest):
try:
# Create a new ThreadRun object
thread_run = await manager.create_thread_run(
thread_id,
model_name=request.model_name,
temperature=request.temperature,
max_tokens=request.max_tokens,
top_p=request.top_p,
tool_choice=request.tool_choice,
execute_tools_async=request.execute_tools_async,
system_message=json.dumps(request.system_message),
tools=json.dumps(request.tools),
response_format=json.dumps(request.response_format),
autonomous_iterations_amount=request.autonomous_iterations_amount,
continue_instructions=request.continue_instructions
)
# Run the thread with the created ThreadRun object
result = await manager.run_thread(
thread_id=thread_id,
thread_run=thread_run,
initializer=get_function(request.initializer),
pre_iteration=get_function(request.pre_iteration),
after_iteration=get_function(request.after_iteration),
finalizer=get_function(request.finalizer),
**request.dict(exclude={'initializer', 'pre_iteration', 'after_iteration', 'finalizer'})
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def get_function(function_name: Optional[str]):
if function_name is None:
return None
# Implement a way to get the function by name, e.g., from a predefined dictionary of functions
# For now, we'll return None
return None
@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools")
async def get_tools():
"""
Retrieve a list of all available tools and their schemas.
"""
tools = tool_registry.get_all_tools()
if not tools:
print("No tools found in the registry") # Debug print
return {
name: {
"name": name,
"description": tool_info['schema']['function']['description'],
"schema": tool_info['schema']
}
for name, tool_info in tools.items()
}
@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run")
async def get_run(
thread_id: str = Path(..., description="The ID of the thread that was run"),
run_id: str = Path(..., description="The ID of the run to retrieve")
):
run = await manager.get_run(thread_id, run_id)
if run is None:
raise HTTPException(status_code=404, detail="Run not found")
return run
@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run")
async def cancel_run(
thread_id: str = Path(..., description="The ID of the thread to which this run belongs"),
run_id: str = Path(..., description="The ID of the run to cancel")
):
"""
Cancels a run that is in_progress.
"""
run = await manager.cancel_run(thread_id, run_id)
if run is None:
raise HTTPException(status_code=404, detail="Run not found")
return run
@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs")
async def list_runs(
thread_id: str = Path(..., description="The ID of the thread the runs belong to"),
limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned")
):
runs = await manager.list_runs(thread_id, limit)
return runs
@app.post("/threads/{thread_id}/runs/{run_id}/stop", response_model=Dict[str, Any], summary="Stop a thread run")
async def stop_thread_run(
thread_id: str = Path(..., description="The ID of the thread"),
run_id: str = Path(..., description="The ID of the run to stop")
):
"""
Stops a thread run that is in progress.
"""
run = await manager.stop_thread_run(thread_id, run_id)
if run is None:
raise HTTPException(status_code=404, detail="Run not found or already completed/stopped")
return run
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -17,43 +17,10 @@ class Thread(Base):
messages = Column(Text)
created_at = Column(Integer)
runs = relationship("ThreadRun", back_populates="thread")
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.created_at = int(datetime.utcnow().timestamp())
class ThreadRun(Base):
__tablename__ = "thread_runs"
id = Column(String, primary_key=True)
thread_id = Column(String, ForeignKey("threads.thread_id"))
status = Column(String, default="queued")
last_error = Column(String, nullable=True)
created_at = Column(Integer)
started_at = Column(Integer, nullable=True)
cancelled_at = Column(Integer, nullable=True)
failed_at = Column(Integer, nullable=True)
completed_at = Column(Integer, nullable=True)
model = Column(String)
temperature = Column(Float)
max_tokens = Column(Integer, nullable=True)
top_p = Column(Float, nullable=True)
tool_choice = Column(String)
execute_tools_async = Column(Boolean)
system_message = Column(JSON)
tools = Column(JSON, nullable=True)
usage = Column(JSON, nullable=True)
response_format = Column(JSON, nullable=True)
autonomous_iterations_amount = Column(Integer, nullable=True)
continue_instructions = Column(String, nullable=True)
iterations = Column(JSON, nullable=True)
thread = relationship("Thread", back_populates="runs")
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.created_at = int(datetime.utcnow().timestamp())
# class MemoryModule(Base):
# __tablename__ = 'memory_modules'

View File

@ -26,7 +26,7 @@ os.environ['GROQ_API_KEY'] = GROQ_API_KEY
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]:
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]:
litellm.set_verbose = True
async def attempt_api_call(api_call_func, max_attempts=3):
@ -69,14 +69,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
api_call_params["max_tokens"] = max_tokens
if tools:
if use_tool_parser:
# Add tools as user messages
tools_message = {"role": "user", "content": json.dumps(tools)}
api_call_params["messages"].append(tools_message)
else:
# Use the existing method of adding tools
api_call_params["tools"] = tools
api_call_params["tool_choice"] = tool_choice
# Use the existing method of adding tools
api_call_params["tools"] = tools
api_call_params["tool_choice"] = tool_choice
if "claude" in model_name.lower() or "anthropic" in model_name.lower():
api_call_params["extra_headers"] = {

View File

@ -1,22 +1,27 @@
import json
import logging
import asyncio
from typing import List, Dict, Any, Optional, Callable
from typing import List, Dict, Any, Optional, Callable, Type
from sqlalchemy import select
from agentpress.db import Database, Thread, ThreadRun
from agentpress.db import Database, Thread
from agentpress.llm import make_llm_api_call
from datetime import datetime, UTC
from agentpress.tool import ToolResult
from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
import uuid
from tools.files_tool import FilesTool
class ThreadManager:
def __init__(self, db: Database):
self.db = db
def __init__(self, db: Optional[Database] = None):
self.db = db if db is not None else Database()
self.tool_registry = ToolRegistry()
self.run_config: Dict[str, Any] = {}
self.current_iteration: int = 0
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None):
"""
Add a tool to the ThreadManager.
If function_names is provided, only register those specific functions.
If function_names is None, register all functions from the tool.
"""
self.tool_registry.register_tool(tool_class, function_names)
async def create_thread(self) -> int:
async with self.db.get_async_session() as session:
@ -25,7 +30,7 @@ class ThreadManager:
)
session.add(new_thread)
await session.commit()
await session.refresh(new_thread) # Ensure thread_id is populated
await session.refresh(new_thread)
return new_thread.thread_id
async def add_message(self, thread_id: int, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
@ -38,9 +43,7 @@ class ThreadManager:
try:
messages = json.loads(thread.messages)
# If we're adding a user message, perform checks
if message_data['role'] == 'user':
# Find the last assistant message with tool calls
last_assistant_index = next((i for i in reversed(range(len(messages))) if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
if last_assistant_index is not None:
@ -50,12 +53,10 @@ class ThreadManager:
if tool_call_count != tool_response_count:
await self.cleanup_incomplete_tool_calls(thread_id)
# Convert ToolResult objects to strings
for key, value in message_data.items():
if isinstance(value, ToolResult):
message_data[key] = str(value)
# Process images if present
if images:
if isinstance(message_data['content'], str):
message_data['content'] = [{"type": "text", "text": message_data['content']}]
@ -81,50 +82,6 @@ class ThreadManager:
logging.error(f"Failed to add message to thread {thread_id}: {e}")
raise e
async def get_message(self, thread_id: int, message_index: int) -> Optional[Dict[str, Any]]:
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if not thread:
return None
messages = json.loads(thread.messages)
if message_index < len(messages):
return messages[message_index]
return None
async def modify_message(self, thread_id: int, message_index: int, new_message_data: Dict[str, Any]):
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if not thread:
raise ValueError(f"Thread with id {thread_id} not found")
try:
messages = json.loads(thread.messages)
if message_index < len(messages):
messages[message_index] = new_message_data
thread.messages = json.dumps(messages)
await session.commit()
else:
raise ValueError(f"Message index {message_index} is out of range")
except Exception as e:
await session.rollback()
raise e
async def remove_message(self, thread_id: int, message_index: int):
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if not thread:
raise ValueError(f"Thread with id {thread_id} not found")
try:
messages = json.loads(thread.messages)
if message_index < len(messages):
del messages[message_index]
thread.messages = json.dumps(messages)
await session.commit()
except Exception as e:
await session.rollback()
raise e
async def list_messages(self, thread_id: int, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]:
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
@ -138,7 +95,7 @@ class ThreadManager:
return [msg]
return []
filtered_messages = messages # Initialize filtered_messages with all messages
filtered_messages = messages
if hide_tool_msgs:
filtered_messages = [
@ -164,7 +121,6 @@ class ThreadManager:
tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:] if m['role'] == 'tool']
if len(tool_calls) != len(tool_responses):
# Create failed ToolResults for incomplete tool calls
failed_tool_results = []
for tool_call in tool_calls[len(tool_responses):]:
failed_tool_result = {
@ -175,7 +131,6 @@ class ThreadManager:
}
failed_tool_results.append(failed_tool_result)
# Insert failed tool results after the last assistant message
assistant_index = messages.index(last_assistant_message)
messages[assistant_index+1:assistant_index+1] = failed_tool_results
@ -188,236 +143,99 @@ class ThreadManager:
return True
return False
async def run_thread(self, settings: Dict[str, Any]) -> Dict[str, Any]:
try:
thread_run = ThreadRun(
id=str(uuid.uuid4()),
thread_id=settings['thread_id'],
status="queued",
model=settings['model_name'],
temperature=settings.get('temperature', 0.7),
max_tokens=settings.get('max_tokens'),
top_p=settings.get('top_p'),
tool_choice=settings.get('tool_choice', 'auto'),
execute_tools_async=settings.get('execute_tools_async', True),
system_message=json.dumps(settings['system_message']),
tools=json.dumps(settings.get('tools')),
response_format=json.dumps(settings.get('response_format')),
autonomous_iterations_amount=settings.get('autonomous_iterations_amount', 1),
continue_instructions=settings.get('continue_instructions')
)
async with self.db.get_async_session() as session:
session.add(thread_run)
await session.commit()
thread_run.status = "in_progress"
thread_run.started_at = int(datetime.now(UTC).timestamp())
await self.update_thread_run(thread_run)
self.run_config = {k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')}
self.run_config['iterations'] = []
if settings.get('initializer'):
settings['initializer']()
# Update thread_run with changes from run_config
for key, value in self.run_config.items():
setattr(thread_run, key, value)
await self.update_thread_run(thread_run)
full_tools = None
if settings.get('tools'):
full_tools = [self.tool_registry.get_tool(tool_name)['schema'] for tool_name in settings['tools'] if self.tool_registry.get_tool(tool_name)]
self.current_iteration = 0
for iteration in range(settings.get('autonomous_iterations_amount', 1)):
self.current_iteration = iteration + 1
if await self.should_stop(settings['thread_id'], thread_run.id):
thread_run.status = "stopped"
thread_run.cancelled_at = int(datetime.now(UTC).timestamp())
await self.update_thread_run(thread_run)
return {"status": "stopped", "message": "Thread run cancelled"}
if settings.get('pre_iteration'):
settings['pre_iteration']()
# Update thread_run with changes from run_config
for key, value in self.run_config.items():
setattr(thread_run, key, value)
await self.update_thread_run(thread_run)
if iteration > 0 and settings.get('continue_instructions'):
await self.add_message(settings['thread_id'], {"role": "user", "content": settings['continue_instructions']})
messages = await self.list_messages(settings['thread_id'], hide_tool_msgs=settings.get('hide_tool_msgs', False))
prepared_messages = [settings['system_message']] + messages
if settings.get('additional_message'):
prepared_messages.append(settings['additional_message'])
response = await make_llm_api_call(
prepared_messages,
settings['model_name'],
temperature=thread_run.temperature,
max_tokens=thread_run.max_tokens,
tools=full_tools,
tool_choice=thread_run.tool_choice,
stream=False,
top_p=thread_run.top_p,
response_format=json.loads(thread_run.response_format) if thread_run.response_format else None
)
usage = response.usage if hasattr(response, 'usage') else None
usage_dict = self.serialize_usage(usage) if usage else None
thread_run.usage = usage_dict
assistant_message = {
"role": "assistant",
"content": response.choices[0].message['content']
}
if 'tool_calls' in response.choices[0].message:
assistant_message['tool_calls'] = response.choices[0].message['tool_calls']
await self.add_message(settings['thread_id'], assistant_message)
if settings.get('tools') is None or settings.get('use_tool_parser', False):
await self.handle_response_without_tools(settings['thread_id'], response, settings.get('use_tool_parser', False))
else:
await self.handle_response_with_tools(settings['thread_id'], response, settings.get('execute_tools_async', True))
self.run_config['iterations'].append({
"iteration": self.current_iteration,
"response": self.serialize_choice(response.choices[0]),
"usage": usage_dict
})
if settings.get('after_iteration'):
settings['after_iteration']()
# Update thread_run with changes from run_config
for key, value in self.run_config.items():
setattr(thread_run, key, value)
thread_run.iterations = json.dumps(self.run_config['iterations'])
await self.update_thread_run(thread_run)
thread_run.status = "completed"
thread_run.completed_at = int(datetime.now(UTC).timestamp())
await self.update_thread_run(thread_run)
self.run_config.update({k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')})
if settings.get('finalizer'):
settings['finalizer']()
# Update thread_run with final changes from run_config
for key, value in self.run_config.items():
setattr(thread_run, key, value)
await self.update_thread_run(thread_run)
return {
"id": thread_run.id,
"status": thread_run.status,
"iterations": self.run_config['iterations'],
"total_iterations": len(self.run_config['iterations']),
"usage": thread_run.usage,
"model": settings['model_name'],
"object": "chat.completion",
"created": int(datetime.now(UTC).timestamp())
}
except Exception as e:
thread_run.status = "failed"
thread_run.failed_at = int(datetime.now(UTC).timestamp())
thread_run.last_error = str(e)
await self.update_thread_run(thread_run)
self.run_config.update({k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')})
if settings.get('finalizer'):
settings['finalizer']()
raise
async def update_thread_run(self, thread_run: ThreadRun):
async with self.db.get_async_session() as session:
session.add(thread_run)
await session.commit()
await session.refresh(thread_run)
async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
response_content = response.choices[0].message['content']
async def run_thread(self, thread_id: int, system_message: Dict[str, Any], model_name: str, temperature: float = 0, max_tokens: Optional[int] = None, tool_choice: str = "auto", additional_message: Optional[Dict[str, Any]] = None, execute_tools_async: bool = True, execute_model_tool_calls: bool = True, use_tools: bool = True) -> Dict[str, Any]:
if use_tool_parser:
await self.handle_tool_parser_response(thread_id, response_content)
else:
# The message has already been added in the run_thread method, so we don't need to add it again here
pass
messages = await self.list_messages(thread_id)
prepared_messages = [system_message] + messages
if additional_message:
prepared_messages.append(additional_message)
tools = self.tool_registry.get_all_tool_schemas() if use_tools else None
async def handle_tool_parser_response(self, thread_id: int, response_content: str):
tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content)
if tool_call_match:
try:
tool_call_json = json.loads(tool_call_match.group())
tool_calls = tool_call_json.get('function_calls', [])
assistant_message = {
"role": "assistant",
"content": response_content,
"tool_calls": [
{
"id": f"call_{i}",
"type": "function",
"function": {
"name": call['name'],
"arguments": json.dumps(call['arguments'])
}
} for i, call in enumerate(tool_calls)
]
try:
llm_response = await make_llm_api_call(
prepared_messages,
model_name,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=tool_choice if use_tools else None,
stream=False
)
except Exception as e:
logging.error(f"Error in API call: {str(e)}")
return {
"status": "error",
"message": str(e),
"run_thread_params": {
"thread_id": thread_id,
"system_message": system_message,
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"tool_choice": tool_choice,
"additional_message": additional_message,
"execute_tools_async": execute_tools_async,
"execute_model_tool_calls": execute_model_tool_calls,
"use_tools": use_tools
}
await self.add_message(thread_id, assistant_message)
}
available_functions = self.get_available_functions()
tool_results = await self.execute_tools(assistant_message['tool_calls'], available_functions, thread_id, execute_tools_async=True)
await self.process_tool_results(thread_id, tool_results)
except json.JSONDecodeError:
logging.error("Failed to parse tool call JSON from response")
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
if use_tools and execute_model_tool_calls:
await self.handle_response_with_tools(thread_id, llm_response, execute_tools_async)
else:
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
await self.handle_response_without_tools(thread_id, llm_response)
return {
"llm_response": llm_response,
"run_thread_params": {
"thread_id": thread_id,
"system_message": system_message,
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"tool_choice": tool_choice,
"additional_message": additional_message,
"execute_tools_async": execute_tools_async,
"execute_model_tool_calls": execute_model_tool_calls,
"use_tools": use_tools
}
}
async def handle_response_without_tools(self, thread_id: int, response: Any):
response_content = response.choices[0].message['content']
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
async def handle_response_with_tools(self, thread_id: int, response: Any, execute_tools_async: bool):
try:
response_message = response.choices[0].message
tool_calls = response_message.get('tool_calls', [])
# The assistant message has already been added in the run_thread method
assistant_message = self.create_assistant_message_with_tools(response_message)
await self.add_message(thread_id, assistant_message)
available_functions = self.get_available_functions()
if await self.should_stop(thread_id, thread_id):
return {"status": "stopped", "message": "Session cancelled"}
if tool_calls:
if execute_tools_async:
tool_results = await self.execute_tools_async(tool_calls, available_functions, thread_id)
else:
tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id)
# Add tool results to messages
for result in tool_results:
await self.add_message(thread_id, result)
if await self.should_stop(thread_id, thread_id):
return {"status": "stopped", "message": "Session cancelled after tool execution"}
except AttributeError as e:
logging.error(f"AttributeError: {e}")
# No need to add the message here as it's already been added in the run_thread method
response_content = response.choices[0].message['content']
await self.add_message(thread_id, {"role": "assistant", "content": response_content or ""})
def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
message = {
"role": "assistant",
"content": response_message.get('content') or "",
}
tool_calls = response_message.get('tool_calls')
if tool_calls:
message["tool_calls"] = [
@ -441,33 +259,12 @@ class ThreadManager:
available_functions[func_name] = getattr(tool_instance, func_name)
return available_functions
async def execute_tools(self, tool_calls: List[Any], available_functions: Dict[str, Callable], thread_id: int, execute_tools_async: bool) -> List[Dict[str, Any]]:
if execute_tools_async:
return await self.execute_tools_async(tool_calls, available_functions, thread_id)
else:
return await self.execute_tools_sync(tool_calls, available_functions, thread_id)
async def execute_tools_async(self, tool_calls, available_functions, thread_id):
async def execute_single_tool(tool_call):
if await self.should_stop(thread_id, thread_id):
return {"status": "stopped", "message": "Session cancelled"}
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
tool_call_id = tool_call.id
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
error_message = f"Error parsing arguments for {function_name}: {str(e)}"
logging.error(error_message)
logging.error(f"Problematic JSON: {tool_call.function.arguments}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"name": function_name,
"content": str(ToolResult(success=False, output=error_message)),
}
function_to_call = available_functions.get(function_name)
if function_to_call:
return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
@ -481,9 +278,6 @@ class ThreadManager:
async def execute_tools_sync(self, tool_calls, available_functions, thread_id):
tool_results = []
for tool_call in tool_calls:
if await self.should_stop(thread_id, thread_id):
return [{"status": "stopped", "message": "Session cancelled"}]
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
tool_call_id = tool_call.id
@ -498,10 +292,6 @@ class ThreadManager:
return tool_results
async def process_tool_results(self, thread_id: int, tool_results: List[Dict[str, Any]]):
for result in tool_results:
await self.add_message(thread_id, result['tool_message'])
async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id):
try:
function_response = await function_to_call(**function_args)
@ -516,297 +306,60 @@ class ThreadManager:
"content": str(function_response),
}
async def should_stop(self, thread_id: str, run_id: str) -> bool:
async with self.db.get_async_session() as session:
run = await session.get(ThreadRun, run_id)
if run and run.status in ["stopped", "cancelled", "queued"]:
return True
return False
async def stop_thread_run(self, thread_id: str, run_id: str) -> Dict[str, Any]:
async with self.db.get_async_session() as session:
run = await session.get(ThreadRun, run_id)
if run and run.thread_id == thread_id and run.status == "in_progress":
run.status = "stopping"
await session.commit()
return self.serialize_thread_run(run)
return None
async def save_thread_run(self, thread_id: str):
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if not thread:
raise ValueError(f"Thread with id {thread_id} not found")
messages = json.loads(thread.messages)
creation_date = datetime.now().isoformat()
# Get the latest ThreadRun for this thread
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
result = await session.execute(stmt)
latest_thread_run = result.scalar_one_or_none()
if latest_thread_run:
# Update the existing ThreadRun
latest_thread_run.messages = json.dumps(messages)
latest_thread_run.last_updated_date = creation_date
await session.commit()
else:
# Create a new ThreadRun if none exists
new_thread_run = ThreadRun(
thread_id=thread_id,
messages=json.dumps(messages),
creation_date=creation_date,
status='completed'
)
session.add(new_thread_run)
await session.commit()
async def get_thread(self, thread_id: int) -> Optional[Thread]:
async with self.db.get_async_session() as session:
return await session.get(Thread, thread_id)
async def update_thread_run_with_error(self, thread_id: int, error_message: str):
async with self.db.get_async_session() as session:
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.run_id.desc()).limit(1)
result = await session.execute(stmt)
thread_run = result.scalar_one_or_none()
if thread_run:
thread_run.status = 'error'
thread_run.error_message = error_message # Store the full error message
await session.commit()
async def get_threads(self) -> List[Thread]:
async with self.db.get_async_session() as session:
result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
return result.scalars().all()
async def get_latest_thread_run(self, thread_id: str):
async with self.db.get_async_session() as session:
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
result = await session.execute(stmt)
latest_run = result.scalar_one_or_none()
if latest_run:
return {
"id": latest_run.id,
"status": latest_run.status,
"error_message": latest_run.last_error,
"created_at": latest_run.created_at,
"started_at": latest_run.started_at,
"completed_at": latest_run.completed_at,
"cancelled_at": latest_run.cancelled_at,
"failed_at": latest_run.failed_at,
"model": latest_run.model,
"usage": latest_run.usage
}
return None
async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
async with self.db.get_async_session() as session:
run = await session.get(ThreadRun, run_id)
if run and run.thread_id == thread_id:
return {
"id": run.id,
"thread_id": run.thread_id,
"status": run.status,
"created_at": run.created_at,
"started_at": run.started_at,
"completed_at": run.completed_at,
"cancelled_at": run.cancelled_at,
"failed_at": run.failed_at,
"model": run.model,
"system_message": json.loads(run.system_message) if run.system_message else None,
"tools": json.loads(run.tools) if run.tools else None,
"usage": run.usage,
"temperature": run.temperature,
"top_p": run.top_p,
"max_tokens": run.max_tokens,
"tool_choice": run.tool_choice,
"execute_tools_async": run.execute_tools_async,
"response_format": json.loads(run.response_format) if run.response_format else None,
"last_error": run.last_error
}
return None
async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
async with self.db.get_async_session() as session:
run = await session.get(ThreadRun, run_id)
if run and run.thread_id == thread_id and run.status == "in_progress":
run.status = "cancelled"
run.cancelled_at = int(datetime.now(UTC).timestamp())
await session.commit()
return await self.get_run(thread_id, run_id)
return None
async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
async with self.db.get_async_session() as session:
thread_runs_stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit)
thread_runs_result = await session.execute(thread_runs_stmt)
thread_runs = thread_runs_result.scalars().all()
return [self.serialize_thread_run(run) for run in thread_runs]
async def create_thread_run(self, thread_id: str, **kwargs) -> ThreadRun:
run_id = str(uuid.uuid4())
thread_run = ThreadRun(
id=run_id,
thread_id=thread_id,
status="queued",
model=kwargs.get('model_name'),
temperature=kwargs.get('temperature'),
max_tokens=kwargs.get('max_tokens'),
top_p=kwargs.get('top_p'),
tool_choice=kwargs.get('tool_choice', "auto"),
execute_tools_async=kwargs.get('execute_tools_async', True),
system_message=json.dumps(kwargs.get('system_message')),
tools=json.dumps(kwargs.get('tools')),
response_format=json.dumps(kwargs.get('response_format')),
autonomous_iterations_amount=kwargs.get('autonomous_iterations_amount'),
continue_instructions=kwargs.get('continue_instructions')
)
async with self.db.get_async_session() as session:
session.add(thread_run)
await session.commit()
return thread_run
async def get_thread_run_count(self, thread_id: str) -> int:
async with self.db.get_async_session() as session:
result = await session.execute(select(ThreadRun).filter_by(thread_id=thread_id))
return len(result.all())
async def get_thread_run_status(self, thread_id: str, run_id: str) -> Dict[str, Any]:
async with self.db.get_async_session() as session:
run = await session.get(ThreadRun, run_id)
if run and run.thread_id == thread_id:
return self.serialize_thread_run(run)
return None
def serialize_thread_run(self, run: ThreadRun) -> Dict[str, Any]:
return {
"id": run.id,
"thread_id": run.thread_id,
"status": run.status,
"created_at": run.created_at,
"started_at": run.started_at,
"completed_at": run.completed_at,
"cancelled_at": run.cancelled_at,
"failed_at": run.failed_at,
"model": run.model,
"temperature": run.temperature,
"max_tokens": run.max_tokens,
"top_p": run.top_p,
"tool_choice": run.tool_choice,
"execute_tools_async": run.execute_tools_async,
"system_message": json.loads(run.system_message) if run.system_message else None,
"tools": json.loads(run.tools) if run.tools else None,
"usage": run.usage,
"response_format": json.loads(run.response_format) if run.response_format else None,
"last_error": run.last_error,
"autonomous_iterations_amount": run.autonomous_iterations_amount,
"continue_instructions": run.continue_instructions,
"iterations": json.loads(run.iterations) if run.iterations else None
}
def serialize_usage(self, usage):
return {
"completion_tokens": usage.completion_tokens,
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens,
"completion_tokens_details": self.serialize_completion_tokens_details(usage.completion_tokens_details),
"prompt_tokens_details": self.serialize_prompt_tokens_details(usage.prompt_tokens_details)
}
def serialize_completion_tokens_details(self, details):
return {
"audio_tokens": details.audio_tokens,
"reasoning_tokens": details.reasoning_tokens
}
def serialize_prompt_tokens_details(self, details):
return {
"audio_tokens": details.audio_tokens,
"cached_tokens": details.cached_tokens
}
def serialize_choice(self, choice):
return {
"finish_reason": choice.finish_reason,
"index": choice.index,
"message": self.serialize_message(choice.message)
}
def serialize_message(self, message):
return {
"content": message.content,
"role": message.role,
"tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None
}
def serialize_tool_call(self, tool_call):
return {
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
if __name__ == "__main__":
import asyncio
from agentpress.db import Database
from tools.files_tool import FilesTool
async def main():
db = Database()
manager = ThreadManager(db)
manager = ThreadManager()
manager.add_tool(FilesTool, ['create_file', 'read_file'])
thread_id = await manager.create_thread()
await manager.add_message(thread_id, {"role": "user", "content": "Let's have a conversation about artificial intelligence and create a file summarizing our discussion."})
await manager.add_message(thread_id, {"role": "user", "content": "Please create a file with a random name with the content 'Hello, world!'"})
system_message = {"role": "system", "content": "You are a helpful assistant that can create, read, update, and delete files."}
model_name = "gpt-4o"
system_message = {"role": "system", "content": "You are an AI expert engaging in a conversation about artificial intelligence. You can also create and manage files."}
files_tool = FilesTool()
tool_schemas = files_tool.get_schemas()
# Test with tools
response_with_tools = await manager.run_thread(
thread_id=thread_id,
system_message=system_message,
model_name=model_name,
temperature=0.7,
max_tokens=150,
tool_choice="auto",
additional_message=None,
execute_tools_async=True,
execute_model_tool_calls=True,
use_tools=True
)
def initializer():
print("Initializing thread run...")
manager.run_config['temperature'] = 0.8
print("Response with tools:", response_with_tools)
def pre_iteration():
print(f"Preparing iteration {manager.current_iteration}...")
manager.run_config['max_tokens'] = 200 if manager.current_iteration > 3 else 150
# Test without tools
response_without_tools = await manager.run_thread(
thread_id=thread_id,
system_message=system_message,
model_name=model_name,
temperature=0.7,
max_tokens=150,
additional_message={"role": "user", "content": "What's the capital of France?"},
use_tools=False
)
def after_iteration():
print(f"Completed iteration {manager.current_iteration}. Status: {manager.run_config['status']}")
manager.run_config['continue_instructions'] = "Let's focus more on AI ethics in the next iteration and update our summary file."
def finalizer():
print(f"Thread run finished with status: {manager.run_config['status']}")
print(f"Final configuration: {manager.run_config}")
settings = {
"thread_id": thread_id,
"system_message": system_message,
"model_name": "gpt-4o",
"temperature": 0.7,
"max_tokens": 150,
"autonomous_iterations_amount": 3,
"continue_instructions": "Continue the conversation about AI, introducing new aspects or asking thought-provoking questions. Don't forget to update our summary file.",
"initializer": initializer,
"pre_iteration": pre_iteration,
"after_iteration": after_iteration,
"finalizer": finalizer,
"tools": list(tool_schemas.keys()),
"tool_choice": "auto"
}
response = await manager.run_thread(settings)
print(f"Thread run response: {response}")
print("Response without tools:", response_without_tools)
# List messages in the thread
messages = await manager.list_messages(thread_id)
print("\nFinal conversation:")
print("\nMessages in the thread:")
for msg in messages:
print(f"{msg['role'].capitalize()}: {msg['content']}")
asyncio.run(main())
# Run the async main function
asyncio.run(main())

View File

@ -1,23 +1,100 @@
from datetime import datetime
from typing import List, Dict, Any
"""
This module provides the foundation for creating and managing tools in the AgentPress system.
The tool system allows for easy creation of function-like tools that can be used by AI models.
It provides a way to define OpenAPI schemas for these tools, which can then be used to generate
appropriate function calls in the AI model's context.
Key components:
- ToolResult: A dataclass representing the result of a tool execution.
- Tool: An abstract base class that all tools should inherit from.
- tool_schema: A decorator for easily defining OpenAPI schemas for tool methods.
Usage:
1. Create a new tool by subclassing Tool.
2. Define methods in your tool class and decorate them with @tool_schema.
3. The Tool class will automatically register these schemas.
4. Use the tool in your ThreadManager by adding it with add_tool method.
Example:
class MyTool(Tool):
@tool_schema({
"name": "add",
"description": "Add two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"}
},
"required": ["a", "b"]
}
})
async def add(self, a: float, b: float) -> ToolResult:
return self.success_response(f"The sum is {a + b}")
# In your thread manager:
manager.add_tool(MyTool)
"""
from typing import Dict, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from abc import ABC
import json
import inspect
@dataclass
class ToolResult:
"""
Represents the result of a tool execution.
Attributes:
success (bool): Whether the tool execution was successful.
output (str): The output of the tool execution.
"""
success: bool
output: str
class Tool(ABC):
def __init__(self):
pass
"""
Abstract base class for all tools.
This class provides the basic structure and functionality for tools.
Subclasses should implement specific tool methods decorated with @tool_schema.
Methods:
get_schemas(): Returns a dictionary of all registered tool schemas.
success_response(data): Creates a successful ToolResult.
fail_response(msg): Creates a failed ToolResult.
"""
def __init__(self):
self._schemas = {}
self._register_schemas()
def _register_schemas(self):
"""
Automatically registers schemas for all methods decorated with @tool_schema.
"""
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(method, 'schema'):
self._schemas[name] = method.schema
@abstractmethod
def get_schemas(self) -> Dict[str, Dict[str, Any]]:
pass
"""
Returns a dictionary of all registered tool schemas, formatted for use with AI models.
"""
return self._schemas
def success_response(self, data: Dict[str, Any] | str) -> ToolResult:
"""
Creates a successful ToolResult with the given data.
Args:
data: The data to include in the success response.
Returns:
A ToolResult indicating success.
"""
if isinstance(data, str):
text = data
else:
@ -25,10 +102,47 @@ class Tool(ABC):
return ToolResult(success=True, output=text)
def fail_response(self, msg: str) -> ToolResult:
"""
Creates a failed ToolResult with the given error message.
Args:
msg: The error message to include in the failure response.
Returns:
A ToolResult indicating failure.
"""
return ToolResult(success=False, output=msg)
def format_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]:
return {
def tool_schema(schema: Dict[str, Any]):
"""
A decorator for easily defining OpenAPI schemas for tool methods.
This decorator allows you to define the schema for a tool method inline with the method definition.
It attaches the provided schema directly to the method.
Args:
schema (Dict[str, Any]): An OpenAPI schema describing the tool.
Example:
@tool_schema({
"name": "add",
"description": "Add two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"}
},
"required": ["a", "b"]
}
})
async def add(self, a: float, b: float) -> ToolResult:
return self.success_response(f"The sum is {a + b}")
"""
def decorator(func):
func.schema = {
"type": "function",
"function": schema
}
}
return func
return decorator

View File

@ -1,4 +1,4 @@
from typing import Dict, Type, Any, Optional
from typing import Dict, Type, Any, List, Optional
from agentpress.tool import Tool
from agentpress.config import settings
import importlib.util
@ -8,39 +8,34 @@ import inspect
class ToolRegistry:
def __init__(self):
self.tools: Dict[str, Dict[str, Any]] = {}
self.register_all_tools()
def register_tool(self, tool_cls: Type[Tool]):
def register_tool(self, tool_cls: Type[Tool], function_names: Optional[List[str]] = None):
tool_instance = tool_cls()
schemas = tool_instance.get_schemas()
for func_name, schema in schemas.items():
self.tools[func_name] = {
"instance": tool_instance,
"schema": schema
}
if function_names is None:
# Register all functions
for func_name, schema in schemas.items():
self.tools[func_name] = {
"instance": tool_instance,
"schema": schema
}
else:
# Register only specified functions
for func_name in function_names:
if func_name in schemas:
self.tools[func_name] = {
"instance": tool_instance,
"schema": schemas[func_name]
}
else:
raise ValueError(f"Function '{func_name}' not found in {tool_cls.__name__}")
def register_all_tools(self):
tools_dir = settings.tools_dir
for file in os.listdir(tools_dir):
if file.endswith('.py') and file not in ['__init__.py', 'tool.py', 'tool_registry.py']:
module_path = os.path.join(tools_dir, file)
module_name = os.path.splitext(file)[0]
try:
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, Tool) and obj != Tool:
print(f"Registering tool: {name}") # Debug print
self.register_tool(obj)
except Exception as e:
print(f"Error importing {module_path}: {e}") # Debug print
print(f"Registered tools: {list(self.tools.keys())}") # Debug print
def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
return self.tools.get(tool_name)
def get_tool(self, tool_name: str) -> Dict[str, Any]:
return self.tools.get(tool_name, {})
def get_all_tools(self) -> Dict[str, Dict[str, Any]]:
return self.tools
return self.tools
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
return [tool_info['schema'] for tool_info in self.tools.values()]

View File

@ -1,4 +0,0 @@
from .main import main
from .tool_display import display_tools
__all__ = ['main', 'display_tools']

View File

@ -1,44 +0,0 @@
import streamlit as st
from agentpress.ui.thread_management import display_thread_management
from agentpress.ui.message_display import display_messages_and_runner
from agentpress.ui.thread_runner import fetch_thread_runs, display_runs
from agentpress.ui.tool_display import display_tools
from agentpress.ui.utils import initialize_session_state, fetch_data, API_BASE_URL
def main():
initialize_session_state()
fetch_data()
st.set_page_config(page_title="AI Assistant Management System", layout="wide")
st.sidebar.title("Navigation")
mode = st.sidebar.radio("Select Mode", ["Thread Management", "Tools"])
st.title("AI Assistant Management System")
if mode == "Tools":
display_tools()
else: # Thread Management
display_thread_management_content()
def display_thread_management_content():
col1, col2 = st.columns([1, 3])
with col1:
display_thread_management()
if st.session_state.selected_thread:
display_thread_runner(st.session_state.selected_thread)
with col2:
if st.session_state.selected_thread:
display_messages_and_runner(st.session_state.selected_thread)
def display_thread_runner(thread_id):
limit = st.number_input("Number of runs to retrieve", min_value=1, max_value=100, value=20)
if st.button("Fetch Runs"):
runs = fetch_thread_runs(thread_id, limit)
# display_runs(runs)
if __name__ == "__main__":
main()

View File

@ -1,106 +0,0 @@
import streamlit as st
import requests
from agentpress.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
from agentpress.ui.thread_runner import prepare_run_thread_data, run_thread, display_response_content
def display_messages_and_runner(thread_id):
st.subheader(f"Messages for Thread: {thread_id}")
messages_container = st.empty()
def update_messages():
messages = fetch_messages(thread_id)
with messages_container.container():
display_message_list(messages)
update_messages()
display_add_message_form(thread_id, update_messages)
display_run_thread_form(thread_id, update_messages)
def fetch_messages(thread_id):
messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/")
if messages_response.status_code == 200:
return messages_response.json()
else:
st.error("Failed to fetch messages.")
return []
def display_message_list(messages):
for msg in messages:
with st.chat_message(msg['role']):
st.write(msg['content'])
def display_add_message_form(thread_id, update_callback):
st.write("### Add a New Message")
with st.form(key="add_message_form"):
role = st.selectbox("Role", ["user", "assistant"], key="add_role")
content = st.text_area("Content", key="add_content")
submitted = st.form_submit_button("Add Message")
if submitted:
add_message(thread_id, role, content)
update_callback()
def display_run_thread_form(thread_id, update_callback):
st.write("### Run Thread")
with st.form(key="run_thread_form"):
model_name = st.selectbox("Model", AI_MODELS, key="model_name")
temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
additional_system_message = st.text_area("Additional System Message", key="additional_system_message", height=100)
tool_options = list(st.session_state.tools.keys())
selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
autonomous_iterations_amount = st.number_input("Autonomous Iterations", min_value=1, max_value=10, value=1, key="autonomous_iterations_amount")
continue_instructions = st.text_area("Continue Instructions", key="continue_instructions", height=100, help="Instructions for continuing the conversation in subsequent iterations")
initializer = st.text_input("Initializer Function", key="initializer")
pre_iteration = st.text_input("Pre-iteration Function", key="pre_iteration")
after_iteration = st.text_input("After-iteration Function", key="after_iteration")
finalizer = st.text_input("Finalizer Function", key="finalizer")
submitted = st.form_submit_button("Run Thread")
if submitted:
run_thread_data = prepare_run_thread_data(
model_name, temperature, max_tokens, system_message,
additional_system_message, selected_tools,
autonomous_iterations_amount, continue_instructions,
initializer, pre_iteration, after_iteration, finalizer
)
response = run_thread(thread_id, run_thread_data)
if response:
st.subheader("Thread Run Response")
st.json(response)
# Update messages
update_callback()
def add_message(thread_id, role, content):
message_data = {"role": role, "content": content}
add_msg_response = requests.post(
f"{API_BASE_URL}/threads/{thread_id}/messages/",
json=message_data
)
if add_msg_response.status_code == 200:
st.success("Message added successfully.")
else:
st.error("Failed to add message.")
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools, autonomous_iterations_amount, continue_instructions, initializer, pre_iteration, after_iteration, finalizer):
return {
"system_message": {"role": "system", "content": system_message},
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"tools": selected_tools,
"additional_system_message": additional_system_message,
"tool_choice": "auto",
"autonomous_iterations_amount": autonomous_iterations_amount,
"continue_instructions": continue_instructions,
"initializer": initializer,
"pre_iteration": pre_iteration,
"after_iteration": after_iteration,
"finalizer": finalizer
}

View File

@ -1,22 +0,0 @@
import streamlit as st
import time
import requests
from agentpress.ui.utils import API_BASE_URL
def get_run_status(thread_id, run_id, is_agent_run):
endpoint = f"agent_runs" if is_agent_run else f"runs"
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/{endpoint}/{run_id}/status")
if response.status_code == 200:
return response.json()
return None
def real_time_status_update(thread_id, run_id, is_agent_run):
status_placeholder = st.empty()
while True:
status = get_run_status(thread_id, run_id, is_agent_run)
if status:
status_placeholder.write(f"Current status: {status['status']}")
if status['status'] in ['completed', 'failed', 'cancelled']:
break
time.sleep(1)
return status

View File

@ -1,87 +0,0 @@
import streamlit as st
import requests
from agentpress.ui.utils import API_BASE_URL
from datetime import datetime
from agentpress.ui.thread_runner import stop_thread_run, stop_agent_run, get_thread_run_status, get_agent_run_status
def display_thread_management():
st.subheader("Thread Management")
if st.button(" Create New Thread", key="create_thread_button"):
create_new_thread()
display_thread_selector()
if st.session_state.selected_thread:
display_run_history(st.session_state.selected_thread)
def create_new_thread():
response = requests.post(f"{API_BASE_URL}/threads/")
if response.status_code == 200:
thread_id = response.json()['thread_id']
st.session_state.selected_thread = thread_id
st.success(f"New thread created with ID: {thread_id}")
st.rerun()
else:
st.error("Failed to create a new thread.")
def display_thread_selector():
threads_response = requests.get(f"{API_BASE_URL}/threads/")
if threads_response.status_code == 200:
threads = threads_response.json()
# Sort threads by created_at timestamp (newest first)
sorted_threads = sorted(threads, key=lambda x: x['created_at'], reverse=True)
thread_options = [f"{thread['thread_id']} - Created: {format_timestamp(thread['created_at'])}" for thread in sorted_threads]
if st.session_state.selected_thread is None and sorted_threads:
st.session_state.selected_thread = sorted_threads[0]['thread_id']
selected_thread = st.selectbox(
"🔍 Select Thread",
thread_options,
key="thread_select",
index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0)
)
if selected_thread:
st.session_state.selected_thread = selected_thread.split(' - ')[0]
else:
st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}")
def display_run_history(thread_id):
st.subheader("Run History")
# Fetch thread runs
thread_runs = fetch_thread_runs(thread_id)
# Display thread runs
st.write("### Thread Runs")
for run in thread_runs:
with st.expander(f"Run {run['id']} - Status: {run['status']}"):
st.write(f"Created At: {format_timestamp(run['created_at'])}")
st.write(f"Status: {run['status']}")
st.write(f"Iterations: {len(run.get('iterations', []))} / {run.get('autonomous_iterations_amount', 1)}")
if run['status'] == "in_progress":
if st.button(f"Stop Run {run['id']}", key=f"stop_thread_run_{run['id']}"):
stop_thread_run(thread_id, run['id'])
st.rerun()
if st.button(f"Refresh Status for Run {run['id']}", key=f"refresh_thread_run_{run['id']}"):
updated_run = get_thread_run_status(thread_id, run['id'])
if updated_run:
run.update(updated_run)
st.rerun()
def fetch_thread_runs(thread_id):
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs")
if response.status_code == 200:
return response.json()
else:
st.error("Failed to fetch thread runs.")
return []
def format_timestamp(timestamp):
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')

View File

@ -1,195 +0,0 @@
import streamlit as st
import requests
from agentpress.ui.utils import API_BASE_URL
from datetime import datetime
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools):
return {
"system_message": {"role": "system", "content": system_message},
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"tools": selected_tools,
"additional_system_message": additional_system_message,
"tool_choice": "auto" # Add this line to ensure tool_choice is always set
}
def prepare_run_thread_agent_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools, autonomous_iterations_amount, continue_instructions):
return {
"system_message": {"role": "system", "content": system_message},
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"tools": selected_tools,
"additional_system_message": additional_system_message,
"autonomous_iterations_amount": autonomous_iterations_amount,
"continue_instructions": continue_instructions
}
def run_thread(thread_id, run_thread_data):
with st.spinner("Running thread..."):
try:
run_thread_response = requests.post(
f"{API_BASE_URL}/threads/{thread_id}/run/",
json=run_thread_data
)
run_thread_response.raise_for_status()
response_data = run_thread_response.json()
st.success(f"Thread run completed successfully! Status: {response_data.get('status', 'Unknown')}")
if 'id' in response_data:
st.session_state.latest_run_id = response_data['id']
st.subheader("Response Content")
display_response_content(response_data)
# Display the full response data
st.subheader("Full Response Data")
st.json(response_data)
return response_data # Return the response data
except requests.exceptions.RequestException as e:
st.error(f"Failed to run thread. Error: {str(e)}")
if hasattr(e, 'response') and e.response is not None:
st.text("Response content:")
st.text(e.response.text)
except Exception as e:
st.error(f"An unexpected error occurred: {str(e)}")
return None # Return None if there was an error
def run_thread_agent(thread_id, run_thread_agent_data):
with st.spinner("Running thread agent..."):
try:
run_thread_response = requests.post(
f"{API_BASE_URL}/threads/{thread_id}/run_agent/",
json=run_thread_agent_data
)
run_thread_response.raise_for_status()
response_data = run_thread_response.json()
st.success(f"Thread agent run completed successfully! Status: {response_data['status']}")
st.subheader("Agent Response")
display_agent_response_content(response_data)
# Display the full response data
st.subheader("Full Agent Response Data")
st.json(response_data)
return response_data
except requests.exceptions.RequestException as e:
st.error(f"Failed to run thread agent. Error: {str(e)}")
if hasattr(e, 'response') and e.response is not None:
st.text("Response content:")
st.text(e.response.text)
except Exception as e:
st.error(f"An unexpected error occurred: {str(e)}")
def display_response_content(response_data):
if isinstance(response_data, dict) and 'choices' in response_data:
message = response_data['choices'][0]['message']
st.write(f"**Role:** {message['role']}")
st.write(f"**Content:** {message['content']}")
if 'tool_calls' in message and message['tool_calls']:
st.write("**Tool Calls:**")
for tool_call in message['tool_calls']:
st.write(f"- Function: `{tool_call['function']['name']}`")
st.code(tool_call['function']['arguments'], language="json")
else:
st.json(response_data)
def display_agent_response_content(response_data):
st.write(f"**Status:** {response_data['status']}")
st.write(f"**Total Iterations:** {response_data['total_iterations']}")
st.write(f"**Completed Iterations:** {response_data.get('iterations_count', 'N/A')}")
for i, iteration in enumerate(response_data['iterations']):
with st.expander(f"Iteration {i+1}"):
display_response_content(iteration)
st.write("**Final Configuration:**")
st.json(response_data['final_config'])
def fetch_thread_runs(thread_id, limit):
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs?limit={limit}")
if response.status_code == 200:
return response.json()
else:
st.error("Failed to retrieve runs.")
return []
def format_timestamp(timestamp):
if timestamp:
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
return 'N/A'
def display_runs(runs):
for run in runs:
with st.expander(f"Run {run['id']} - Status: {run['status']}", expanded=False):
col1, col2 = st.columns(2)
with col1:
st.write(f"**Created At:** {format_timestamp(run['created_at'])}")
st.write(f"**Started At:** {format_timestamp(run['started_at'])}")
st.write(f"**Completed At:** {format_timestamp(run['completed_at'])}")
st.write(f"**Cancelled At:** {format_timestamp(run['cancelled_at'])}")
st.write(f"**Failed At:** {format_timestamp(run['failed_at'])}")
with col2:
st.write(f"**Model:** {run['model']}")
st.write(f"**Temperature:** {run['temperature']}")
st.write(f"**Top P:** {run['top_p']}")
st.write(f"**Max Tokens:** {run['max_tokens']}")
st.write(f"**Tool Choice:** {run['tool_choice']}")
st.write(f"**Execute Tools Async:** {run['execute_tools_async']}")
st.write(f"**Autonomous Iterations:** {run['autonomous_iterations_amount']}")
st.write("**System Message:**")
st.json(run['system_message'])
if run['tools']:
st.write("**Tools:**")
st.json(run['tools'])
if run['usage']:
st.write("**Usage:**")
st.json(run['usage'])
if run['response_format']:
st.write("**Response Format:**")
st.json(run['response_format'])
if run['last_error']:
st.error("**Last Error:**")
st.code(run['last_error'])
if run['continue_instructions']:
st.write("**Continue Instructions:**")
st.text(run['continue_instructions'])
if run['status'] == "in_progress":
if st.button(f"Stop Run {run['id']}", key=f"stop_button_{run['id']}"):
stop_thread_run(run['thread_id'], run['id'])
st.rerun()
if st.button(f"Refresh Status for Run {run['id']}", key=f"refresh_button_{run['id']}"):
updated_run = get_thread_run_status(run['thread_id'], run['id'])
if updated_run:
run.update(updated_run)
st.rerun()
def stop_thread_run(thread_id, run_id):
response = requests.post(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/stop")
if response.status_code == 200:
st.success("Thread run stopped successfully.")
return response.json()
else:
st.error(f"Failed to stop thread run. Status code: {response.status_code}")
return None
def get_thread_run_status(thread_id, run_id):
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/status")
if response.status_code == 200:
return response.json()
else:
st.error(f"Failed to get thread run status. Status code: {response.status_code}")
return None

View File

@ -1,43 +0,0 @@
import streamlit as st
import requests
from agentpress.ui.utils import API_BASE_URL
def display_tools():
st.header("Available Tools")
tools = fetch_tools()
if not tools:
st.warning("No tools available. Please check the API connection.")
return
view_mode = st.radio("View Mode", ["Simple", "Detailed"])
if view_mode == "Simple":
display_simple_view(tools)
else:
display_detailed_view(tools)
def display_simple_view(tools):
for tool_name, tool_info in tools.items():
with st.expander(f"🛠️ {tool_name}"):
st.write(f"**Description:** {tool_info['description']}")
def display_detailed_view(tools):
for tool_name, tool_info in tools.items():
with st.expander(f"🛠️ {tool_name}"):
st.write(f"**Description:** {tool_info['description']}")
if tool_info['schema']:
st.write("**Schema:**")
st.json(tool_info['schema'])
def fetch_tools():
response = requests.get(f"{API_BASE_URL}/tools/")
if response.status_code == 200:
tools = response.json()
st.session_state.tools = tools
return tools
else:
st.error(f"Failed to fetch tools. Status code: {response.status_code}")
st.error(f"Error message: {response.text}")
return {}

View File

@ -1,23 +0,0 @@
import streamlit as st
import requests
from agentpress.constants import AI_MODELS, STANDARD_SYSTEM_MESSAGE
API_BASE_URL = "http://localhost:8000"
def fetch_tools():
response = requests.get(f"{API_BASE_URL}/tools/")
if response.status_code == 200:
st.session_state.tools = response.json()
else:
st.error("Failed to fetch tools.")
def initialize_session_state():
if 'selected_thread' not in st.session_state:
st.session_state.selected_thread = None
if 'tools' not in st.session_state:
st.session_state.tools = []
if 'fetch_tools' not in st.session_state:
st.session_state.fetch_tools = fetch_tools
def fetch_data():
fetch_tools()