mirror of https://github.com/kortix-ai/suna.git
This commit is contained in:
parent
21e06f3162
commit
3f69ea9cc4
|
@ -1,10 +1,10 @@
|
|||
from .config import settings
|
||||
from .db import Database, Thread, ThreadRun
|
||||
from .db import Database, Thread
|
||||
from .llm import make_llm_api_call
|
||||
from .thread_manager import ThreadManager
|
||||
# from .working_memory_manager import WorkingMemory
|
||||
|
||||
__all__ = [
|
||||
'settings', 'Database', 'Thread', 'ThreadRun',
|
||||
'settings', 'Database', 'Thread',
|
||||
'make_llm_api_call', 'ThreadManager'
|
||||
] #'WorkingMemory'
|
|
@ -1,182 +0,0 @@
|
|||
from fastapi import FastAPI, HTTPException, Query, Path
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from agentpress.db import Database
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from agentpress.tool_registry import ToolRegistry
|
||||
from agentpress.config import Settings
|
||||
|
||||
app = FastAPI(
|
||||
title="Thread Manager API",
|
||||
description="API for managing and running threads with LLM integration",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
db = Database()
|
||||
manager = ThreadManager(db)
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str = Field(..., description="The role of the message sender (e.g., 'user', 'assistant')")
|
||||
content: str = Field(..., description="The content of the message")
|
||||
|
||||
class RunThreadRequest(BaseModel):
|
||||
system_message: Dict[str, Any] = Field(..., description="The system message to be used for the thread run")
|
||||
model_name: str = Field(..., description="The name of the LLM model to be used")
|
||||
temperature: float = Field(0.5, description="The sampling temperature for the LLM")
|
||||
max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate")
|
||||
tools: Optional[List[str]] = Field(None, description="The list of tools to be used in the thread run")
|
||||
tool_choice: str = Field("auto", description="Controls which tool is called by the model")
|
||||
additional_system_message: Optional[str] = Field(None, description="Additional system message to be appended to the existing system message")
|
||||
additional_message: Optional[Dict[str, Any]] = Field(None, description="Additional message to be appended at the end of the conversation")
|
||||
hide_tool_msgs: bool = Field(False, description="Whether to hide tool messages in the conversation history")
|
||||
execute_tools_async: bool = Field(True, description="Whether to execute tools asynchronously")
|
||||
use_tool_parser: bool = Field(False, description="Whether to use the tool parser for handling tool calls")
|
||||
top_p: Optional[float] = Field(None, description="The nucleus sampling value")
|
||||
response_format: Optional[Dict[str, Any]] = Field(None, description="Specifies the format that the model must output")
|
||||
autonomous_iterations_amount: Optional[int] = Field(None, description="The number of autonomous iterations to perform")
|
||||
continue_instructions: Optional[str] = Field(None, description="Instructions for continuing the conversation in subsequent iterations")
|
||||
initializer: Optional[str] = Field(None, description="Name of the initializer function")
|
||||
pre_iteration: Optional[str] = Field(None, description="Name of the pre-iteration function")
|
||||
after_iteration: Optional[str] = Field(None, description="Name of the after-iteration function")
|
||||
finalizer: Optional[str] = Field(None, description="Name of the finalizer function")
|
||||
|
||||
@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread")
|
||||
async def create_thread():
|
||||
"""
|
||||
Create a new thread and return its ID.
|
||||
"""
|
||||
thread_id = await manager.create_thread()
|
||||
return {"thread_id": thread_id}
|
||||
|
||||
@app.get("/threads/", response_model=List[Dict[str, Any]], summary="Get all threads")
|
||||
async def get_threads():
|
||||
"""
|
||||
Retrieve a list of all threads.
|
||||
"""
|
||||
threads = await manager.get_threads()
|
||||
return [{"thread_id": thread.thread_id, "created_at": thread.created_at} for thread in threads]
|
||||
|
||||
@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread")
|
||||
async def add_message(thread_id: str, message: Message):
|
||||
"""
|
||||
Add a new message to the specified thread.
|
||||
"""
|
||||
await manager.add_message(thread_id, message.dict())
|
||||
return {"status": "success"}
|
||||
|
||||
@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread")
|
||||
async def list_messages(thread_id: str):
|
||||
"""
|
||||
Retrieve all messages from the specified thread.
|
||||
"""
|
||||
messages = await manager.list_messages(thread_id)
|
||||
return messages
|
||||
|
||||
@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread")
|
||||
async def run_thread(thread_id: str, request: RunThreadRequest):
|
||||
try:
|
||||
# Create a new ThreadRun object
|
||||
thread_run = await manager.create_thread_run(
|
||||
thread_id,
|
||||
model_name=request.model_name,
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
top_p=request.top_p,
|
||||
tool_choice=request.tool_choice,
|
||||
execute_tools_async=request.execute_tools_async,
|
||||
system_message=json.dumps(request.system_message),
|
||||
tools=json.dumps(request.tools),
|
||||
response_format=json.dumps(request.response_format),
|
||||
autonomous_iterations_amount=request.autonomous_iterations_amount,
|
||||
continue_instructions=request.continue_instructions
|
||||
)
|
||||
|
||||
# Run the thread with the created ThreadRun object
|
||||
result = await manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
thread_run=thread_run,
|
||||
initializer=get_function(request.initializer),
|
||||
pre_iteration=get_function(request.pre_iteration),
|
||||
after_iteration=get_function(request.after_iteration),
|
||||
finalizer=get_function(request.finalizer),
|
||||
**request.dict(exclude={'initializer', 'pre_iteration', 'after_iteration', 'finalizer'})
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def get_function(function_name: Optional[str]):
|
||||
if function_name is None:
|
||||
return None
|
||||
# Implement a way to get the function by name, e.g., from a predefined dictionary of functions
|
||||
# For now, we'll return None
|
||||
return None
|
||||
|
||||
@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools")
|
||||
async def get_tools():
|
||||
"""
|
||||
Retrieve a list of all available tools and their schemas.
|
||||
"""
|
||||
tools = tool_registry.get_all_tools()
|
||||
if not tools:
|
||||
print("No tools found in the registry") # Debug print
|
||||
return {
|
||||
name: {
|
||||
"name": name,
|
||||
"description": tool_info['schema']['function']['description'],
|
||||
"schema": tool_info['schema']
|
||||
}
|
||||
for name, tool_info in tools.items()
|
||||
}
|
||||
|
||||
@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run")
|
||||
async def get_run(
|
||||
thread_id: str = Path(..., description="The ID of the thread that was run"),
|
||||
run_id: str = Path(..., description="The ID of the run to retrieve")
|
||||
):
|
||||
run = await manager.get_run(thread_id, run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
return run
|
||||
|
||||
@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run")
|
||||
async def cancel_run(
|
||||
thread_id: str = Path(..., description="The ID of the thread to which this run belongs"),
|
||||
run_id: str = Path(..., description="The ID of the run to cancel")
|
||||
):
|
||||
"""
|
||||
Cancels a run that is in_progress.
|
||||
"""
|
||||
run = await manager.cancel_run(thread_id, run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
return run
|
||||
|
||||
@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs")
|
||||
async def list_runs(
|
||||
thread_id: str = Path(..., description="The ID of the thread the runs belong to"),
|
||||
limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned")
|
||||
):
|
||||
runs = await manager.list_runs(thread_id, limit)
|
||||
return runs
|
||||
|
||||
@app.post("/threads/{thread_id}/runs/{run_id}/stop", response_model=Dict[str, Any], summary="Stop a thread run")
|
||||
async def stop_thread_run(
|
||||
thread_id: str = Path(..., description="The ID of the thread"),
|
||||
run_id: str = Path(..., description="The ID of the run to stop")
|
||||
):
|
||||
"""
|
||||
Stops a thread run that is in progress.
|
||||
"""
|
||||
run = await manager.stop_thread_run(thread_id, run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail="Run not found or already completed/stopped")
|
||||
return run
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
@ -17,43 +17,10 @@ class Thread(Base):
|
|||
messages = Column(Text)
|
||||
created_at = Column(Integer)
|
||||
|
||||
runs = relationship("ThreadRun", back_populates="thread")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.created_at = int(datetime.utcnow().timestamp())
|
||||
|
||||
class ThreadRun(Base):
|
||||
__tablename__ = "thread_runs"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
thread_id = Column(String, ForeignKey("threads.thread_id"))
|
||||
status = Column(String, default="queued")
|
||||
last_error = Column(String, nullable=True)
|
||||
created_at = Column(Integer)
|
||||
started_at = Column(Integer, nullable=True)
|
||||
cancelled_at = Column(Integer, nullable=True)
|
||||
failed_at = Column(Integer, nullable=True)
|
||||
completed_at = Column(Integer, nullable=True)
|
||||
model = Column(String)
|
||||
temperature = Column(Float)
|
||||
max_tokens = Column(Integer, nullable=True)
|
||||
top_p = Column(Float, nullable=True)
|
||||
tool_choice = Column(String)
|
||||
execute_tools_async = Column(Boolean)
|
||||
system_message = Column(JSON)
|
||||
tools = Column(JSON, nullable=True)
|
||||
usage = Column(JSON, nullable=True)
|
||||
response_format = Column(JSON, nullable=True)
|
||||
autonomous_iterations_amount = Column(Integer, nullable=True)
|
||||
continue_instructions = Column(String, nullable=True)
|
||||
iterations = Column(JSON, nullable=True)
|
||||
|
||||
thread = relationship("Thread", back_populates="runs")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.created_at = int(datetime.utcnow().timestamp())
|
||||
|
||||
# class MemoryModule(Base):
|
||||
# __tablename__ = 'memory_modules'
|
||||
|
|
|
@ -26,7 +26,7 @@ os.environ['GROQ_API_KEY'] = GROQ_API_KEY
|
|||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]:
|
||||
async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]:
|
||||
litellm.set_verbose = True
|
||||
|
||||
async def attempt_api_call(api_call_func, max_attempts=3):
|
||||
|
@ -69,14 +69,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0
|
|||
api_call_params["max_tokens"] = max_tokens
|
||||
|
||||
if tools:
|
||||
if use_tool_parser:
|
||||
# Add tools as user messages
|
||||
tools_message = {"role": "user", "content": json.dumps(tools)}
|
||||
api_call_params["messages"].append(tools_message)
|
||||
else:
|
||||
# Use the existing method of adding tools
|
||||
api_call_params["tools"] = tools
|
||||
api_call_params["tool_choice"] = tool_choice
|
||||
# Use the existing method of adding tools
|
||||
api_call_params["tools"] = tools
|
||||
api_call_params["tool_choice"] = tool_choice
|
||||
|
||||
if "claude" in model_name.lower() or "anthropic" in model_name.lower():
|
||||
api_call_params["extra_headers"] = {
|
||||
|
|
|
@ -1,22 +1,27 @@
|
|||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from typing import List, Dict, Any, Optional, Callable, Type
|
||||
from sqlalchemy import select
|
||||
from agentpress.db import Database, Thread, ThreadRun
|
||||
from agentpress.db import Database, Thread
|
||||
from agentpress.llm import make_llm_api_call
|
||||
from datetime import datetime, UTC
|
||||
from agentpress.tool import ToolResult
|
||||
from agentpress.tool import Tool, ToolResult
|
||||
from agentpress.tool_registry import ToolRegistry
|
||||
import uuid
|
||||
from tools.files_tool import FilesTool
|
||||
|
||||
class ThreadManager:
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
def __init__(self, db: Optional[Database] = None):
|
||||
self.db = db if db is not None else Database()
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.run_config: Dict[str, Any] = {}
|
||||
self.current_iteration: int = 0
|
||||
|
||||
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None):
|
||||
"""
|
||||
Add a tool to the ThreadManager.
|
||||
If function_names is provided, only register those specific functions.
|
||||
If function_names is None, register all functions from the tool.
|
||||
"""
|
||||
self.tool_registry.register_tool(tool_class, function_names)
|
||||
|
||||
async def create_thread(self) -> int:
|
||||
async with self.db.get_async_session() as session:
|
||||
|
@ -25,7 +30,7 @@ class ThreadManager:
|
|||
)
|
||||
session.add(new_thread)
|
||||
await session.commit()
|
||||
await session.refresh(new_thread) # Ensure thread_id is populated
|
||||
await session.refresh(new_thread)
|
||||
return new_thread.thread_id
|
||||
|
||||
async def add_message(self, thread_id: int, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
|
||||
|
@ -38,9 +43,7 @@ class ThreadManager:
|
|||
try:
|
||||
messages = json.loads(thread.messages)
|
||||
|
||||
# If we're adding a user message, perform checks
|
||||
if message_data['role'] == 'user':
|
||||
# Find the last assistant message with tool calls
|
||||
last_assistant_index = next((i for i in reversed(range(len(messages))) if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
|
||||
|
||||
if last_assistant_index is not None:
|
||||
|
@ -50,12 +53,10 @@ class ThreadManager:
|
|||
if tool_call_count != tool_response_count:
|
||||
await self.cleanup_incomplete_tool_calls(thread_id)
|
||||
|
||||
# Convert ToolResult objects to strings
|
||||
for key, value in message_data.items():
|
||||
if isinstance(value, ToolResult):
|
||||
message_data[key] = str(value)
|
||||
|
||||
# Process images if present
|
||||
if images:
|
||||
if isinstance(message_data['content'], str):
|
||||
message_data['content'] = [{"type": "text", "text": message_data['content']}]
|
||||
|
@ -81,50 +82,6 @@ class ThreadManager:
|
|||
logging.error(f"Failed to add message to thread {thread_id}: {e}")
|
||||
raise e
|
||||
|
||||
async def get_message(self, thread_id: int, message_index: int) -> Optional[Dict[str, Any]]:
|
||||
async with self.db.get_async_session() as session:
|
||||
thread = await session.get(Thread, thread_id)
|
||||
if not thread:
|
||||
return None
|
||||
messages = json.loads(thread.messages)
|
||||
if message_index < len(messages):
|
||||
return messages[message_index]
|
||||
return None
|
||||
|
||||
async def modify_message(self, thread_id: int, message_index: int, new_message_data: Dict[str, Any]):
|
||||
async with self.db.get_async_session() as session:
|
||||
thread = await session.get(Thread, thread_id)
|
||||
if not thread:
|
||||
raise ValueError(f"Thread with id {thread_id} not found")
|
||||
|
||||
try:
|
||||
messages = json.loads(thread.messages)
|
||||
if message_index < len(messages):
|
||||
messages[message_index] = new_message_data
|
||||
thread.messages = json.dumps(messages)
|
||||
await session.commit()
|
||||
else:
|
||||
raise ValueError(f"Message index {message_index} is out of range")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def remove_message(self, thread_id: int, message_index: int):
|
||||
async with self.db.get_async_session() as session:
|
||||
thread = await session.get(Thread, thread_id)
|
||||
if not thread:
|
||||
raise ValueError(f"Thread with id {thread_id} not found")
|
||||
|
||||
try:
|
||||
messages = json.loads(thread.messages)
|
||||
if message_index < len(messages):
|
||||
del messages[message_index]
|
||||
thread.messages = json.dumps(messages)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def list_messages(self, thread_id: int, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]:
|
||||
async with self.db.get_async_session() as session:
|
||||
thread = await session.get(Thread, thread_id)
|
||||
|
@ -138,7 +95,7 @@ class ThreadManager:
|
|||
return [msg]
|
||||
return []
|
||||
|
||||
filtered_messages = messages # Initialize filtered_messages with all messages
|
||||
filtered_messages = messages
|
||||
|
||||
if hide_tool_msgs:
|
||||
filtered_messages = [
|
||||
|
@ -164,7 +121,6 @@ class ThreadManager:
|
|||
tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:] if m['role'] == 'tool']
|
||||
|
||||
if len(tool_calls) != len(tool_responses):
|
||||
# Create failed ToolResults for incomplete tool calls
|
||||
failed_tool_results = []
|
||||
for tool_call in tool_calls[len(tool_responses):]:
|
||||
failed_tool_result = {
|
||||
|
@ -175,7 +131,6 @@ class ThreadManager:
|
|||
}
|
||||
failed_tool_results.append(failed_tool_result)
|
||||
|
||||
# Insert failed tool results after the last assistant message
|
||||
assistant_index = messages.index(last_assistant_message)
|
||||
messages[assistant_index+1:assistant_index+1] = failed_tool_results
|
||||
|
||||
|
@ -188,236 +143,99 @@ class ThreadManager:
|
|||
return True
|
||||
return False
|
||||
|
||||
async def run_thread(self, settings: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
thread_run = ThreadRun(
|
||||
id=str(uuid.uuid4()),
|
||||
thread_id=settings['thread_id'],
|
||||
status="queued",
|
||||
model=settings['model_name'],
|
||||
temperature=settings.get('temperature', 0.7),
|
||||
max_tokens=settings.get('max_tokens'),
|
||||
top_p=settings.get('top_p'),
|
||||
tool_choice=settings.get('tool_choice', 'auto'),
|
||||
execute_tools_async=settings.get('execute_tools_async', True),
|
||||
system_message=json.dumps(settings['system_message']),
|
||||
tools=json.dumps(settings.get('tools')),
|
||||
response_format=json.dumps(settings.get('response_format')),
|
||||
autonomous_iterations_amount=settings.get('autonomous_iterations_amount', 1),
|
||||
continue_instructions=settings.get('continue_instructions')
|
||||
)
|
||||
|
||||
async with self.db.get_async_session() as session:
|
||||
session.add(thread_run)
|
||||
await session.commit()
|
||||
|
||||
thread_run.status = "in_progress"
|
||||
thread_run.started_at = int(datetime.now(UTC).timestamp())
|
||||
await self.update_thread_run(thread_run)
|
||||
|
||||
self.run_config = {k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')}
|
||||
self.run_config['iterations'] = []
|
||||
|
||||
if settings.get('initializer'):
|
||||
settings['initializer']()
|
||||
# Update thread_run with changes from run_config
|
||||
for key, value in self.run_config.items():
|
||||
setattr(thread_run, key, value)
|
||||
await self.update_thread_run(thread_run)
|
||||
|
||||
full_tools = None
|
||||
if settings.get('tools'):
|
||||
full_tools = [self.tool_registry.get_tool(tool_name)['schema'] for tool_name in settings['tools'] if self.tool_registry.get_tool(tool_name)]
|
||||
|
||||
self.current_iteration = 0
|
||||
for iteration in range(settings.get('autonomous_iterations_amount', 1)):
|
||||
self.current_iteration = iteration + 1
|
||||
|
||||
if await self.should_stop(settings['thread_id'], thread_run.id):
|
||||
thread_run.status = "stopped"
|
||||
thread_run.cancelled_at = int(datetime.now(UTC).timestamp())
|
||||
await self.update_thread_run(thread_run)
|
||||
return {"status": "stopped", "message": "Thread run cancelled"}
|
||||
|
||||
if settings.get('pre_iteration'):
|
||||
settings['pre_iteration']()
|
||||
# Update thread_run with changes from run_config
|
||||
for key, value in self.run_config.items():
|
||||
setattr(thread_run, key, value)
|
||||
await self.update_thread_run(thread_run)
|
||||
|
||||
if iteration > 0 and settings.get('continue_instructions'):
|
||||
await self.add_message(settings['thread_id'], {"role": "user", "content": settings['continue_instructions']})
|
||||
|
||||
messages = await self.list_messages(settings['thread_id'], hide_tool_msgs=settings.get('hide_tool_msgs', False))
|
||||
prepared_messages = [settings['system_message']] + messages
|
||||
|
||||
if settings.get('additional_message'):
|
||||
prepared_messages.append(settings['additional_message'])
|
||||
|
||||
response = await make_llm_api_call(
|
||||
prepared_messages,
|
||||
settings['model_name'],
|
||||
temperature=thread_run.temperature,
|
||||
max_tokens=thread_run.max_tokens,
|
||||
tools=full_tools,
|
||||
tool_choice=thread_run.tool_choice,
|
||||
stream=False,
|
||||
top_p=thread_run.top_p,
|
||||
response_format=json.loads(thread_run.response_format) if thread_run.response_format else None
|
||||
)
|
||||
|
||||
usage = response.usage if hasattr(response, 'usage') else None
|
||||
usage_dict = self.serialize_usage(usage) if usage else None
|
||||
thread_run.usage = usage_dict
|
||||
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": response.choices[0].message['content']
|
||||
}
|
||||
if 'tool_calls' in response.choices[0].message:
|
||||
assistant_message['tool_calls'] = response.choices[0].message['tool_calls']
|
||||
|
||||
await self.add_message(settings['thread_id'], assistant_message)
|
||||
|
||||
if settings.get('tools') is None or settings.get('use_tool_parser', False):
|
||||
await self.handle_response_without_tools(settings['thread_id'], response, settings.get('use_tool_parser', False))
|
||||
else:
|
||||
await self.handle_response_with_tools(settings['thread_id'], response, settings.get('execute_tools_async', True))
|
||||
|
||||
self.run_config['iterations'].append({
|
||||
"iteration": self.current_iteration,
|
||||
"response": self.serialize_choice(response.choices[0]),
|
||||
"usage": usage_dict
|
||||
})
|
||||
|
||||
if settings.get('after_iteration'):
|
||||
settings['after_iteration']()
|
||||
# Update thread_run with changes from run_config
|
||||
for key, value in self.run_config.items():
|
||||
setattr(thread_run, key, value)
|
||||
|
||||
thread_run.iterations = json.dumps(self.run_config['iterations'])
|
||||
await self.update_thread_run(thread_run)
|
||||
|
||||
thread_run.status = "completed"
|
||||
thread_run.completed_at = int(datetime.now(UTC).timestamp())
|
||||
await self.update_thread_run(thread_run)
|
||||
|
||||
self.run_config.update({k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')})
|
||||
|
||||
if settings.get('finalizer'):
|
||||
settings['finalizer']()
|
||||
# Update thread_run with final changes from run_config
|
||||
for key, value in self.run_config.items():
|
||||
setattr(thread_run, key, value)
|
||||
await self.update_thread_run(thread_run)
|
||||
|
||||
return {
|
||||
"id": thread_run.id,
|
||||
"status": thread_run.status,
|
||||
"iterations": self.run_config['iterations'],
|
||||
"total_iterations": len(self.run_config['iterations']),
|
||||
"usage": thread_run.usage,
|
||||
"model": settings['model_name'],
|
||||
"object": "chat.completion",
|
||||
"created": int(datetime.now(UTC).timestamp())
|
||||
}
|
||||
except Exception as e:
|
||||
thread_run.status = "failed"
|
||||
thread_run.failed_at = int(datetime.now(UTC).timestamp())
|
||||
thread_run.last_error = str(e)
|
||||
await self.update_thread_run(thread_run)
|
||||
self.run_config.update({k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')})
|
||||
if settings.get('finalizer'):
|
||||
settings['finalizer']()
|
||||
raise
|
||||
|
||||
async def update_thread_run(self, thread_run: ThreadRun):
|
||||
async with self.db.get_async_session() as session:
|
||||
session.add(thread_run)
|
||||
await session.commit()
|
||||
await session.refresh(thread_run)
|
||||
|
||||
async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
|
||||
response_content = response.choices[0].message['content']
|
||||
async def run_thread(self, thread_id: int, system_message: Dict[str, Any], model_name: str, temperature: float = 0, max_tokens: Optional[int] = None, tool_choice: str = "auto", additional_message: Optional[Dict[str, Any]] = None, execute_tools_async: bool = True, execute_model_tool_calls: bool = True, use_tools: bool = True) -> Dict[str, Any]:
|
||||
|
||||
if use_tool_parser:
|
||||
await self.handle_tool_parser_response(thread_id, response_content)
|
||||
else:
|
||||
# The message has already been added in the run_thread method, so we don't need to add it again here
|
||||
pass
|
||||
messages = await self.list_messages(thread_id)
|
||||
prepared_messages = [system_message] + messages
|
||||
|
||||
if additional_message:
|
||||
prepared_messages.append(additional_message)
|
||||
|
||||
tools = self.tool_registry.get_all_tool_schemas() if use_tools else None
|
||||
|
||||
async def handle_tool_parser_response(self, thread_id: int, response_content: str):
|
||||
tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content)
|
||||
if tool_call_match:
|
||||
try:
|
||||
tool_call_json = json.loads(tool_call_match.group())
|
||||
tool_calls = tool_call_json.get('function_calls', [])
|
||||
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": response_content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call['name'],
|
||||
"arguments": json.dumps(call['arguments'])
|
||||
}
|
||||
} for i, call in enumerate(tool_calls)
|
||||
]
|
||||
try:
|
||||
llm_response = await make_llm_api_call(
|
||||
prepared_messages,
|
||||
model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice if use_tools else None,
|
||||
stream=False
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in API call: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"run_thread_params": {
|
||||
"thread_id": thread_id,
|
||||
"system_message": system_message,
|
||||
"model_name": model_name,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"tool_choice": tool_choice,
|
||||
"additional_message": additional_message,
|
||||
"execute_tools_async": execute_tools_async,
|
||||
"execute_model_tool_calls": execute_model_tool_calls,
|
||||
"use_tools": use_tools
|
||||
}
|
||||
await self.add_message(thread_id, assistant_message)
|
||||
}
|
||||
|
||||
available_functions = self.get_available_functions()
|
||||
|
||||
tool_results = await self.execute_tools(assistant_message['tool_calls'], available_functions, thread_id, execute_tools_async=True)
|
||||
|
||||
await self.process_tool_results(thread_id, tool_results)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logging.error("Failed to parse tool call JSON from response")
|
||||
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
|
||||
if use_tools and execute_model_tool_calls:
|
||||
await self.handle_response_with_tools(thread_id, llm_response, execute_tools_async)
|
||||
else:
|
||||
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
|
||||
await self.handle_response_without_tools(thread_id, llm_response)
|
||||
|
||||
return {
|
||||
"llm_response": llm_response,
|
||||
"run_thread_params": {
|
||||
"thread_id": thread_id,
|
||||
"system_message": system_message,
|
||||
"model_name": model_name,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"tool_choice": tool_choice,
|
||||
"additional_message": additional_message,
|
||||
"execute_tools_async": execute_tools_async,
|
||||
"execute_model_tool_calls": execute_model_tool_calls,
|
||||
"use_tools": use_tools
|
||||
}
|
||||
}
|
||||
|
||||
async def handle_response_without_tools(self, thread_id: int, response: Any):
|
||||
response_content = response.choices[0].message['content']
|
||||
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
|
||||
|
||||
async def handle_response_with_tools(self, thread_id: int, response: Any, execute_tools_async: bool):
|
||||
try:
|
||||
response_message = response.choices[0].message
|
||||
tool_calls = response_message.get('tool_calls', [])
|
||||
|
||||
# The assistant message has already been added in the run_thread method
|
||||
|
||||
assistant_message = self.create_assistant_message_with_tools(response_message)
|
||||
await self.add_message(thread_id, assistant_message)
|
||||
|
||||
available_functions = self.get_available_functions()
|
||||
|
||||
if await self.should_stop(thread_id, thread_id):
|
||||
return {"status": "stopped", "message": "Session cancelled"}
|
||||
|
||||
if tool_calls:
|
||||
if execute_tools_async:
|
||||
tool_results = await self.execute_tools_async(tool_calls, available_functions, thread_id)
|
||||
else:
|
||||
tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id)
|
||||
|
||||
# Add tool results to messages
|
||||
for result in tool_results:
|
||||
await self.add_message(thread_id, result)
|
||||
|
||||
if await self.should_stop(thread_id, thread_id):
|
||||
return {"status": "stopped", "message": "Session cancelled after tool execution"}
|
||||
|
||||
except AttributeError as e:
|
||||
logging.error(f"AttributeError: {e}")
|
||||
# No need to add the message here as it's already been added in the run_thread method
|
||||
response_content = response.choices[0].message['content']
|
||||
await self.add_message(thread_id, {"role": "assistant", "content": response_content or ""})
|
||||
|
||||
def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": response_message.get('content') or "",
|
||||
}
|
||||
|
||||
tool_calls = response_message.get('tool_calls')
|
||||
if tool_calls:
|
||||
message["tool_calls"] = [
|
||||
|
@ -441,33 +259,12 @@ class ThreadManager:
|
|||
available_functions[func_name] = getattr(tool_instance, func_name)
|
||||
return available_functions
|
||||
|
||||
async def execute_tools(self, tool_calls: List[Any], available_functions: Dict[str, Callable], thread_id: int, execute_tools_async: bool) -> List[Dict[str, Any]]:
|
||||
if execute_tools_async:
|
||||
return await self.execute_tools_async(tool_calls, available_functions, thread_id)
|
||||
else:
|
||||
return await self.execute_tools_sync(tool_calls, available_functions, thread_id)
|
||||
|
||||
async def execute_tools_async(self, tool_calls, available_functions, thread_id):
|
||||
async def execute_single_tool(tool_call):
|
||||
if await self.should_stop(thread_id, thread_id):
|
||||
return {"status": "stopped", "message": "Session cancelled"}
|
||||
|
||||
function_name = tool_call.function.name
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
tool_call_id = tool_call.id
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
error_message = f"Error parsing arguments for {function_name}: {str(e)}"
|
||||
logging.error(error_message)
|
||||
logging.error(f"Problematic JSON: {tool_call.function.arguments}")
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": function_name,
|
||||
"content": str(ToolResult(success=False, output=error_message)),
|
||||
}
|
||||
|
||||
function_to_call = available_functions.get(function_name)
|
||||
if function_to_call:
|
||||
return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
|
||||
|
@ -481,9 +278,6 @@ class ThreadManager:
|
|||
async def execute_tools_sync(self, tool_calls, available_functions, thread_id):
|
||||
tool_results = []
|
||||
for tool_call in tool_calls:
|
||||
if await self.should_stop(thread_id, thread_id):
|
||||
return [{"status": "stopped", "message": "Session cancelled"}]
|
||||
|
||||
function_name = tool_call.function.name
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
tool_call_id = tool_call.id
|
||||
|
@ -498,10 +292,6 @@ class ThreadManager:
|
|||
|
||||
return tool_results
|
||||
|
||||
async def process_tool_results(self, thread_id: int, tool_results: List[Dict[str, Any]]):
|
||||
for result in tool_results:
|
||||
await self.add_message(thread_id, result['tool_message'])
|
||||
|
||||
async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id):
|
||||
try:
|
||||
function_response = await function_to_call(**function_args)
|
||||
|
@ -516,297 +306,60 @@ class ThreadManager:
|
|||
"content": str(function_response),
|
||||
}
|
||||
|
||||
async def should_stop(self, thread_id: str, run_id: str) -> bool:
|
||||
async with self.db.get_async_session() as session:
|
||||
run = await session.get(ThreadRun, run_id)
|
||||
if run and run.status in ["stopped", "cancelled", "queued"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def stop_thread_run(self, thread_id: str, run_id: str) -> Dict[str, Any]:
|
||||
async with self.db.get_async_session() as session:
|
||||
run = await session.get(ThreadRun, run_id)
|
||||
if run and run.thread_id == thread_id and run.status == "in_progress":
|
||||
run.status = "stopping"
|
||||
await session.commit()
|
||||
return self.serialize_thread_run(run)
|
||||
return None
|
||||
|
||||
async def save_thread_run(self, thread_id: str):
|
||||
async with self.db.get_async_session() as session:
|
||||
thread = await session.get(Thread, thread_id)
|
||||
if not thread:
|
||||
raise ValueError(f"Thread with id {thread_id} not found")
|
||||
|
||||
messages = json.loads(thread.messages)
|
||||
creation_date = datetime.now().isoformat()
|
||||
|
||||
# Get the latest ThreadRun for this thread
|
||||
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
|
||||
result = await session.execute(stmt)
|
||||
latest_thread_run = result.scalar_one_or_none()
|
||||
|
||||
if latest_thread_run:
|
||||
# Update the existing ThreadRun
|
||||
latest_thread_run.messages = json.dumps(messages)
|
||||
latest_thread_run.last_updated_date = creation_date
|
||||
await session.commit()
|
||||
else:
|
||||
# Create a new ThreadRun if none exists
|
||||
new_thread_run = ThreadRun(
|
||||
thread_id=thread_id,
|
||||
messages=json.dumps(messages),
|
||||
creation_date=creation_date,
|
||||
status='completed'
|
||||
)
|
||||
session.add(new_thread_run)
|
||||
await session.commit()
|
||||
|
||||
async def get_thread(self, thread_id: int) -> Optional[Thread]:
|
||||
async with self.db.get_async_session() as session:
|
||||
return await session.get(Thread, thread_id)
|
||||
|
||||
async def update_thread_run_with_error(self, thread_id: int, error_message: str):
|
||||
async with self.db.get_async_session() as session:
|
||||
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.run_id.desc()).limit(1)
|
||||
result = await session.execute(stmt)
|
||||
thread_run = result.scalar_one_or_none()
|
||||
if thread_run:
|
||||
thread_run.status = 'error'
|
||||
thread_run.error_message = error_message # Store the full error message
|
||||
await session.commit()
|
||||
|
||||
async def get_threads(self) -> List[Thread]:
|
||||
async with self.db.get_async_session() as session:
|
||||
result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_latest_thread_run(self, thread_id: str):
|
||||
async with self.db.get_async_session() as session:
|
||||
stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
|
||||
result = await session.execute(stmt)
|
||||
latest_run = result.scalar_one_or_none()
|
||||
if latest_run:
|
||||
return {
|
||||
"id": latest_run.id,
|
||||
"status": latest_run.status,
|
||||
"error_message": latest_run.last_error,
|
||||
"created_at": latest_run.created_at,
|
||||
"started_at": latest_run.started_at,
|
||||
"completed_at": latest_run.completed_at,
|
||||
"cancelled_at": latest_run.cancelled_at,
|
||||
"failed_at": latest_run.failed_at,
|
||||
"model": latest_run.model,
|
||||
"usage": latest_run.usage
|
||||
}
|
||||
return None
|
||||
|
||||
async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.db.get_async_session() as session:
|
||||
run = await session.get(ThreadRun, run_id)
|
||||
if run and run.thread_id == thread_id:
|
||||
return {
|
||||
"id": run.id,
|
||||
"thread_id": run.thread_id,
|
||||
"status": run.status,
|
||||
"created_at": run.created_at,
|
||||
"started_at": run.started_at,
|
||||
"completed_at": run.completed_at,
|
||||
"cancelled_at": run.cancelled_at,
|
||||
"failed_at": run.failed_at,
|
||||
"model": run.model,
|
||||
"system_message": json.loads(run.system_message) if run.system_message else None,
|
||||
"tools": json.loads(run.tools) if run.tools else None,
|
||||
"usage": run.usage,
|
||||
"temperature": run.temperature,
|
||||
"top_p": run.top_p,
|
||||
"max_tokens": run.max_tokens,
|
||||
"tool_choice": run.tool_choice,
|
||||
"execute_tools_async": run.execute_tools_async,
|
||||
"response_format": json.loads(run.response_format) if run.response_format else None,
|
||||
"last_error": run.last_error
|
||||
}
|
||||
return None
|
||||
|
||||
async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.db.get_async_session() as session:
|
||||
run = await session.get(ThreadRun, run_id)
|
||||
if run and run.thread_id == thread_id and run.status == "in_progress":
|
||||
run.status = "cancelled"
|
||||
run.cancelled_at = int(datetime.now(UTC).timestamp())
|
||||
await session.commit()
|
||||
return await self.get_run(thread_id, run_id)
|
||||
return None
|
||||
|
||||
async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
|
||||
async with self.db.get_async_session() as session:
|
||||
thread_runs_stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit)
|
||||
thread_runs_result = await session.execute(thread_runs_stmt)
|
||||
thread_runs = thread_runs_result.scalars().all()
|
||||
return [self.serialize_thread_run(run) for run in thread_runs]
|
||||
|
||||
async def create_thread_run(self, thread_id: str, **kwargs) -> ThreadRun:
|
||||
run_id = str(uuid.uuid4())
|
||||
thread_run = ThreadRun(
|
||||
id=run_id,
|
||||
thread_id=thread_id,
|
||||
status="queued",
|
||||
model=kwargs.get('model_name'),
|
||||
temperature=kwargs.get('temperature'),
|
||||
max_tokens=kwargs.get('max_tokens'),
|
||||
top_p=kwargs.get('top_p'),
|
||||
tool_choice=kwargs.get('tool_choice', "auto"),
|
||||
execute_tools_async=kwargs.get('execute_tools_async', True),
|
||||
system_message=json.dumps(kwargs.get('system_message')),
|
||||
tools=json.dumps(kwargs.get('tools')),
|
||||
response_format=json.dumps(kwargs.get('response_format')),
|
||||
autonomous_iterations_amount=kwargs.get('autonomous_iterations_amount'),
|
||||
continue_instructions=kwargs.get('continue_instructions')
|
||||
)
|
||||
async with self.db.get_async_session() as session:
|
||||
session.add(thread_run)
|
||||
await session.commit()
|
||||
return thread_run
|
||||
|
||||
async def get_thread_run_count(self, thread_id: str) -> int:
|
||||
async with self.db.get_async_session() as session:
|
||||
result = await session.execute(select(ThreadRun).filter_by(thread_id=thread_id))
|
||||
return len(result.all())
|
||||
|
||||
async def get_thread_run_status(self, thread_id: str, run_id: str) -> Dict[str, Any]:
|
||||
async with self.db.get_async_session() as session:
|
||||
run = await session.get(ThreadRun, run_id)
|
||||
if run and run.thread_id == thread_id:
|
||||
return self.serialize_thread_run(run)
|
||||
return None
|
||||
|
||||
def serialize_thread_run(self, run: ThreadRun) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"thread_id": run.thread_id,
|
||||
"status": run.status,
|
||||
"created_at": run.created_at,
|
||||
"started_at": run.started_at,
|
||||
"completed_at": run.completed_at,
|
||||
"cancelled_at": run.cancelled_at,
|
||||
"failed_at": run.failed_at,
|
||||
"model": run.model,
|
||||
"temperature": run.temperature,
|
||||
"max_tokens": run.max_tokens,
|
||||
"top_p": run.top_p,
|
||||
"tool_choice": run.tool_choice,
|
||||
"execute_tools_async": run.execute_tools_async,
|
||||
"system_message": json.loads(run.system_message) if run.system_message else None,
|
||||
"tools": json.loads(run.tools) if run.tools else None,
|
||||
"usage": run.usage,
|
||||
"response_format": json.loads(run.response_format) if run.response_format else None,
|
||||
"last_error": run.last_error,
|
||||
"autonomous_iterations_amount": run.autonomous_iterations_amount,
|
||||
"continue_instructions": run.continue_instructions,
|
||||
"iterations": json.loads(run.iterations) if run.iterations else None
|
||||
}
|
||||
|
||||
def serialize_usage(self, usage):
|
||||
return {
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"total_tokens": usage.total_tokens,
|
||||
"completion_tokens_details": self.serialize_completion_tokens_details(usage.completion_tokens_details),
|
||||
"prompt_tokens_details": self.serialize_prompt_tokens_details(usage.prompt_tokens_details)
|
||||
}
|
||||
|
||||
def serialize_completion_tokens_details(self, details):
|
||||
return {
|
||||
"audio_tokens": details.audio_tokens,
|
||||
"reasoning_tokens": details.reasoning_tokens
|
||||
}
|
||||
|
||||
def serialize_prompt_tokens_details(self, details):
|
||||
return {
|
||||
"audio_tokens": details.audio_tokens,
|
||||
"cached_tokens": details.cached_tokens
|
||||
}
|
||||
|
||||
def serialize_choice(self, choice):
|
||||
return {
|
||||
"finish_reason": choice.finish_reason,
|
||||
"index": choice.index,
|
||||
"message": self.serialize_message(choice.message)
|
||||
}
|
||||
|
||||
def serialize_message(self, message):
|
||||
return {
|
||||
"content": message.content,
|
||||
"role": message.role,
|
||||
"tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None
|
||||
}
|
||||
|
||||
def serialize_tool_call(self, tool_call):
|
||||
return {
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
from agentpress.db import Database
|
||||
from tools.files_tool import FilesTool
|
||||
|
||||
async def main():
|
||||
db = Database()
|
||||
manager = ThreadManager(db)
|
||||
|
||||
manager = ThreadManager()
|
||||
|
||||
manager.add_tool(FilesTool, ['create_file', 'read_file'])
|
||||
|
||||
thread_id = await manager.create_thread()
|
||||
await manager.add_message(thread_id, {"role": "user", "content": "Let's have a conversation about artificial intelligence and create a file summarizing our discussion."})
|
||||
|
||||
await manager.add_message(thread_id, {"role": "user", "content": "Please create a file with a random name with the content 'Hello, world!'"})
|
||||
|
||||
system_message = {"role": "system", "content": "You are a helpful assistant that can create, read, update, and delete files."}
|
||||
model_name = "gpt-4o"
|
||||
|
||||
system_message = {"role": "system", "content": "You are an AI expert engaging in a conversation about artificial intelligence. You can also create and manage files."}
|
||||
|
||||
files_tool = FilesTool()
|
||||
tool_schemas = files_tool.get_schemas()
|
||||
# Test with tools
|
||||
response_with_tools = await manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
system_message=system_message,
|
||||
model_name=model_name,
|
||||
temperature=0.7,
|
||||
max_tokens=150,
|
||||
tool_choice="auto",
|
||||
additional_message=None,
|
||||
execute_tools_async=True,
|
||||
execute_model_tool_calls=True,
|
||||
use_tools=True
|
||||
)
|
||||
|
||||
def initializer():
|
||||
print("Initializing thread run...")
|
||||
manager.run_config['temperature'] = 0.8
|
||||
print("Response with tools:", response_with_tools)
|
||||
|
||||
def pre_iteration():
|
||||
print(f"Preparing iteration {manager.current_iteration}...")
|
||||
manager.run_config['max_tokens'] = 200 if manager.current_iteration > 3 else 150
|
||||
# Test without tools
|
||||
response_without_tools = await manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
system_message=system_message,
|
||||
model_name=model_name,
|
||||
temperature=0.7,
|
||||
max_tokens=150,
|
||||
additional_message={"role": "user", "content": "What's the capital of France?"},
|
||||
use_tools=False
|
||||
)
|
||||
|
||||
def after_iteration():
|
||||
print(f"Completed iteration {manager.current_iteration}. Status: {manager.run_config['status']}")
|
||||
manager.run_config['continue_instructions'] = "Let's focus more on AI ethics in the next iteration and update our summary file."
|
||||
|
||||
def finalizer():
|
||||
print(f"Thread run finished with status: {manager.run_config['status']}")
|
||||
print(f"Final configuration: {manager.run_config}")
|
||||
|
||||
settings = {
|
||||
"thread_id": thread_id,
|
||||
"system_message": system_message,
|
||||
"model_name": "gpt-4o",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 150,
|
||||
"autonomous_iterations_amount": 3,
|
||||
"continue_instructions": "Continue the conversation about AI, introducing new aspects or asking thought-provoking questions. Don't forget to update our summary file.",
|
||||
"initializer": initializer,
|
||||
"pre_iteration": pre_iteration,
|
||||
"after_iteration": after_iteration,
|
||||
"finalizer": finalizer,
|
||||
"tools": list(tool_schemas.keys()),
|
||||
"tool_choice": "auto"
|
||||
}
|
||||
|
||||
response = await manager.run_thread(settings)
|
||||
|
||||
print(f"Thread run response: {response}")
|
||||
print("Response without tools:", response_without_tools)
|
||||
|
||||
# List messages in the thread
|
||||
messages = await manager.list_messages(thread_id)
|
||||
print("\nFinal conversation:")
|
||||
print("\nMessages in the thread:")
|
||||
for msg in messages:
|
||||
print(f"{msg['role'].capitalize()}: {msg['content']}")
|
||||
|
||||
asyncio.run(main())
|
||||
# Run the async main function
|
||||
asyncio.run(main())
|
||||
|
|
|
@ -1,23 +1,100 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
"""
|
||||
This module provides the foundation for creating and managing tools in the AgentPress system.
|
||||
|
||||
The tool system allows for easy creation of function-like tools that can be used by AI models.
|
||||
It provides a way to define OpenAPI schemas for these tools, which can then be used to generate
|
||||
appropriate function calls in the AI model's context.
|
||||
|
||||
Key components:
|
||||
- ToolResult: A dataclass representing the result of a tool execution.
|
||||
- Tool: An abstract base class that all tools should inherit from.
|
||||
- tool_schema: A decorator for easily defining OpenAPI schemas for tool methods.
|
||||
|
||||
Usage:
|
||||
1. Create a new tool by subclassing Tool.
|
||||
2. Define methods in your tool class and decorate them with @tool_schema.
|
||||
3. The Tool class will automatically register these schemas.
|
||||
4. Use the tool in your ThreadManager by adding it with add_tool method.
|
||||
|
||||
Example:
|
||||
class MyTool(Tool):
|
||||
@tool_schema({
|
||||
"name": "add",
|
||||
"description": "Add two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "First number"},
|
||||
"b": {"type": "number", "description": "Second number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
})
|
||||
async def add(self, a: float, b: float) -> ToolResult:
|
||||
return self.success_response(f"The sum is {a + b}")
|
||||
|
||||
# In your thread manager:
|
||||
manager.add_tool(MyTool)
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC
|
||||
import json
|
||||
import inspect
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""
|
||||
Represents the result of a tool execution.
|
||||
|
||||
Attributes:
|
||||
success (bool): Whether the tool execution was successful.
|
||||
output (str): The output of the tool execution.
|
||||
"""
|
||||
success: bool
|
||||
output: str
|
||||
|
||||
class Tool(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
"""
|
||||
Abstract base class for all tools.
|
||||
|
||||
This class provides the basic structure and functionality for tools.
|
||||
Subclasses should implement specific tool methods decorated with @tool_schema.
|
||||
|
||||
Methods:
|
||||
get_schemas(): Returns a dictionary of all registered tool schemas.
|
||||
success_response(data): Creates a successful ToolResult.
|
||||
fail_response(msg): Creates a failed ToolResult.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._schemas = {}
|
||||
self._register_schemas()
|
||||
|
||||
def _register_schemas(self):
|
||||
"""
|
||||
Automatically registers schemas for all methods decorated with @tool_schema.
|
||||
"""
|
||||
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
|
||||
if hasattr(method, 'schema'):
|
||||
self._schemas[name] = method.schema
|
||||
|
||||
@abstractmethod
|
||||
def get_schemas(self) -> Dict[str, Dict[str, Any]]:
|
||||
pass
|
||||
"""
|
||||
Returns a dictionary of all registered tool schemas, formatted for use with AI models.
|
||||
"""
|
||||
return self._schemas
|
||||
|
||||
def success_response(self, data: Dict[str, Any] | str) -> ToolResult:
|
||||
"""
|
||||
Creates a successful ToolResult with the given data.
|
||||
|
||||
Args:
|
||||
data: The data to include in the success response.
|
||||
|
||||
Returns:
|
||||
A ToolResult indicating success.
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
text = data
|
||||
else:
|
||||
|
@ -25,10 +102,47 @@ class Tool(ABC):
|
|||
return ToolResult(success=True, output=text)
|
||||
|
||||
def fail_response(self, msg: str) -> ToolResult:
|
||||
"""
|
||||
Creates a failed ToolResult with the given error message.
|
||||
|
||||
Args:
|
||||
msg: The error message to include in the failure response.
|
||||
|
||||
Returns:
|
||||
A ToolResult indicating failure.
|
||||
"""
|
||||
return ToolResult(success=False, output=msg)
|
||||
|
||||
def format_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
def tool_schema(schema: Dict[str, Any]):
|
||||
"""
|
||||
A decorator for easily defining OpenAPI schemas for tool methods.
|
||||
|
||||
This decorator allows you to define the schema for a tool method inline with the method definition.
|
||||
It attaches the provided schema directly to the method.
|
||||
|
||||
Args:
|
||||
schema (Dict[str, Any]): An OpenAPI schema describing the tool.
|
||||
|
||||
Example:
|
||||
@tool_schema({
|
||||
"name": "add",
|
||||
"description": "Add two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "First number"},
|
||||
"b": {"type": "number", "description": "Second number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
})
|
||||
async def add(self, a: float, b: float) -> ToolResult:
|
||||
return self.success_response(f"The sum is {a + b}")
|
||||
"""
|
||||
def decorator(func):
|
||||
func.schema = {
|
||||
"type": "function",
|
||||
"function": schema
|
||||
}
|
||||
}
|
||||
return func
|
||||
return decorator
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Type, Any, Optional
|
||||
from typing import Dict, Type, Any, List, Optional
|
||||
from agentpress.tool import Tool
|
||||
from agentpress.config import settings
|
||||
import importlib.util
|
||||
|
@ -8,39 +8,34 @@ import inspect
|
|||
class ToolRegistry:
|
||||
def __init__(self):
|
||||
self.tools: Dict[str, Dict[str, Any]] = {}
|
||||
self.register_all_tools()
|
||||
|
||||
def register_tool(self, tool_cls: Type[Tool]):
|
||||
def register_tool(self, tool_cls: Type[Tool], function_names: Optional[List[str]] = None):
|
||||
tool_instance = tool_cls()
|
||||
schemas = tool_instance.get_schemas()
|
||||
for func_name, schema in schemas.items():
|
||||
self.tools[func_name] = {
|
||||
"instance": tool_instance,
|
||||
"schema": schema
|
||||
}
|
||||
|
||||
if function_names is None:
|
||||
# Register all functions
|
||||
for func_name, schema in schemas.items():
|
||||
self.tools[func_name] = {
|
||||
"instance": tool_instance,
|
||||
"schema": schema
|
||||
}
|
||||
else:
|
||||
# Register only specified functions
|
||||
for func_name in function_names:
|
||||
if func_name in schemas:
|
||||
self.tools[func_name] = {
|
||||
"instance": tool_instance,
|
||||
"schema": schemas[func_name]
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Function '{func_name}' not found in {tool_cls.__name__}")
|
||||
|
||||
def register_all_tools(self):
|
||||
tools_dir = settings.tools_dir
|
||||
for file in os.listdir(tools_dir):
|
||||
if file.endswith('.py') and file not in ['__init__.py', 'tool.py', 'tool_registry.py']:
|
||||
module_path = os.path.join(tools_dir, file)
|
||||
module_name = os.path.splitext(file)[0]
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, Tool) and obj != Tool:
|
||||
print(f"Registering tool: {name}") # Debug print
|
||||
self.register_tool(obj)
|
||||
except Exception as e:
|
||||
print(f"Error importing {module_path}: {e}") # Debug print
|
||||
|
||||
print(f"Registered tools: {list(self.tools.keys())}") # Debug print
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
return self.tools.get(tool_name)
|
||||
def get_tool(self, tool_name: str) -> Dict[str, Any]:
|
||||
return self.tools.get(tool_name, {})
|
||||
|
||||
def get_all_tools(self) -> Dict[str, Dict[str, Any]]:
|
||||
return self.tools
|
||||
return self.tools
|
||||
|
||||
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [tool_info['schema'] for tool_info in self.tools.values()]
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from .main import main
|
||||
from .tool_display import display_tools
|
||||
|
||||
__all__ = ['main', 'display_tools']
|
|
@ -1,44 +0,0 @@
|
|||
import streamlit as st
|
||||
from agentpress.ui.thread_management import display_thread_management
|
||||
from agentpress.ui.message_display import display_messages_and_runner
|
||||
from agentpress.ui.thread_runner import fetch_thread_runs, display_runs
|
||||
from agentpress.ui.tool_display import display_tools
|
||||
from agentpress.ui.utils import initialize_session_state, fetch_data, API_BASE_URL
|
||||
|
||||
def main():
|
||||
initialize_session_state()
|
||||
fetch_data()
|
||||
|
||||
st.set_page_config(page_title="AI Assistant Management System", layout="wide")
|
||||
|
||||
st.sidebar.title("Navigation")
|
||||
mode = st.sidebar.radio("Select Mode", ["Thread Management", "Tools"])
|
||||
|
||||
st.title("AI Assistant Management System")
|
||||
|
||||
if mode == "Tools":
|
||||
display_tools()
|
||||
else: # Thread Management
|
||||
display_thread_management_content()
|
||||
|
||||
def display_thread_management_content():
|
||||
col1, col2 = st.columns([1, 3])
|
||||
|
||||
with col1:
|
||||
display_thread_management()
|
||||
if st.session_state.selected_thread:
|
||||
display_thread_runner(st.session_state.selected_thread)
|
||||
|
||||
with col2:
|
||||
if st.session_state.selected_thread:
|
||||
display_messages_and_runner(st.session_state.selected_thread)
|
||||
|
||||
def display_thread_runner(thread_id):
|
||||
|
||||
limit = st.number_input("Number of runs to retrieve", min_value=1, max_value=100, value=20)
|
||||
if st.button("Fetch Runs"):
|
||||
runs = fetch_thread_runs(thread_id, limit)
|
||||
# display_runs(runs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,106 +0,0 @@
|
|||
import streamlit as st
|
||||
import requests
|
||||
from agentpress.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
|
||||
from agentpress.ui.thread_runner import prepare_run_thread_data, run_thread, display_response_content
|
||||
|
||||
def display_messages_and_runner(thread_id):
|
||||
st.subheader(f"Messages for Thread: {thread_id}")
|
||||
|
||||
messages_container = st.empty()
|
||||
|
||||
def update_messages():
|
||||
messages = fetch_messages(thread_id)
|
||||
with messages_container.container():
|
||||
display_message_list(messages)
|
||||
|
||||
update_messages()
|
||||
|
||||
display_add_message_form(thread_id, update_messages)
|
||||
display_run_thread_form(thread_id, update_messages)
|
||||
|
||||
def fetch_messages(thread_id):
|
||||
messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/")
|
||||
if messages_response.status_code == 200:
|
||||
return messages_response.json()
|
||||
else:
|
||||
st.error("Failed to fetch messages.")
|
||||
return []
|
||||
|
||||
def display_message_list(messages):
|
||||
for msg in messages:
|
||||
with st.chat_message(msg['role']):
|
||||
st.write(msg['content'])
|
||||
|
||||
def display_add_message_form(thread_id, update_callback):
|
||||
st.write("### Add a New Message")
|
||||
with st.form(key="add_message_form"):
|
||||
role = st.selectbox("Role", ["user", "assistant"], key="add_role")
|
||||
content = st.text_area("Content", key="add_content")
|
||||
submitted = st.form_submit_button("Add Message")
|
||||
if submitted:
|
||||
add_message(thread_id, role, content)
|
||||
update_callback()
|
||||
|
||||
def display_run_thread_form(thread_id, update_callback):
|
||||
st.write("### Run Thread")
|
||||
with st.form(key="run_thread_form"):
|
||||
model_name = st.selectbox("Model", AI_MODELS, key="model_name")
|
||||
temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
|
||||
max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
|
||||
system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
|
||||
additional_system_message = st.text_area("Additional System Message", key="additional_system_message", height=100)
|
||||
|
||||
tool_options = list(st.session_state.tools.keys())
|
||||
selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
|
||||
|
||||
autonomous_iterations_amount = st.number_input("Autonomous Iterations", min_value=1, max_value=10, value=1, key="autonomous_iterations_amount")
|
||||
continue_instructions = st.text_area("Continue Instructions", key="continue_instructions", height=100, help="Instructions for continuing the conversation in subsequent iterations")
|
||||
|
||||
initializer = st.text_input("Initializer Function", key="initializer")
|
||||
pre_iteration = st.text_input("Pre-iteration Function", key="pre_iteration")
|
||||
after_iteration = st.text_input("After-iteration Function", key="after_iteration")
|
||||
finalizer = st.text_input("Finalizer Function", key="finalizer")
|
||||
|
||||
submitted = st.form_submit_button("Run Thread")
|
||||
if submitted:
|
||||
run_thread_data = prepare_run_thread_data(
|
||||
model_name, temperature, max_tokens, system_message,
|
||||
additional_system_message, selected_tools,
|
||||
autonomous_iterations_amount, continue_instructions,
|
||||
initializer, pre_iteration, after_iteration, finalizer
|
||||
)
|
||||
response = run_thread(thread_id, run_thread_data)
|
||||
if response:
|
||||
st.subheader("Thread Run Response")
|
||||
st.json(response)
|
||||
|
||||
# Update messages
|
||||
update_callback()
|
||||
|
||||
def add_message(thread_id, role, content):
|
||||
message_data = {"role": role, "content": content}
|
||||
add_msg_response = requests.post(
|
||||
f"{API_BASE_URL}/threads/{thread_id}/messages/",
|
||||
json=message_data
|
||||
)
|
||||
if add_msg_response.status_code == 200:
|
||||
st.success("Message added successfully.")
|
||||
else:
|
||||
st.error("Failed to add message.")
|
||||
|
||||
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools, autonomous_iterations_amount, continue_instructions, initializer, pre_iteration, after_iteration, finalizer):
|
||||
return {
|
||||
"system_message": {"role": "system", "content": system_message},
|
||||
"model_name": model_name,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"tools": selected_tools,
|
||||
"additional_system_message": additional_system_message,
|
||||
"tool_choice": "auto",
|
||||
"autonomous_iterations_amount": autonomous_iterations_amount,
|
||||
"continue_instructions": continue_instructions,
|
||||
"initializer": initializer,
|
||||
"pre_iteration": pre_iteration,
|
||||
"after_iteration": after_iteration,
|
||||
"finalizer": finalizer
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
import streamlit as st
|
||||
import time
|
||||
import requests
|
||||
from agentpress.ui.utils import API_BASE_URL
|
||||
|
||||
def get_run_status(thread_id, run_id, is_agent_run):
|
||||
endpoint = f"agent_runs" if is_agent_run else f"runs"
|
||||
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/{endpoint}/{run_id}/status")
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return None
|
||||
|
||||
def real_time_status_update(thread_id, run_id, is_agent_run):
|
||||
status_placeholder = st.empty()
|
||||
while True:
|
||||
status = get_run_status(thread_id, run_id, is_agent_run)
|
||||
if status:
|
||||
status_placeholder.write(f"Current status: {status['status']}")
|
||||
if status['status'] in ['completed', 'failed', 'cancelled']:
|
||||
break
|
||||
time.sleep(1)
|
||||
return status
|
|
@ -1,87 +0,0 @@
|
|||
import streamlit as st
|
||||
import requests
|
||||
from agentpress.ui.utils import API_BASE_URL
|
||||
from datetime import datetime
|
||||
from agentpress.ui.thread_runner import stop_thread_run, stop_agent_run, get_thread_run_status, get_agent_run_status
|
||||
|
||||
def display_thread_management():
|
||||
st.subheader("Thread Management")
|
||||
|
||||
if st.button("➕ Create New Thread", key="create_thread_button"):
|
||||
create_new_thread()
|
||||
|
||||
display_thread_selector()
|
||||
|
||||
if st.session_state.selected_thread:
|
||||
display_run_history(st.session_state.selected_thread)
|
||||
|
||||
def create_new_thread():
|
||||
response = requests.post(f"{API_BASE_URL}/threads/")
|
||||
if response.status_code == 200:
|
||||
thread_id = response.json()['thread_id']
|
||||
st.session_state.selected_thread = thread_id
|
||||
st.success(f"New thread created with ID: {thread_id}")
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("Failed to create a new thread.")
|
||||
|
||||
def display_thread_selector():
|
||||
threads_response = requests.get(f"{API_BASE_URL}/threads/")
|
||||
if threads_response.status_code == 200:
|
||||
threads = threads_response.json()
|
||||
|
||||
# Sort threads by created_at timestamp (newest first)
|
||||
sorted_threads = sorted(threads, key=lambda x: x['created_at'], reverse=True)
|
||||
|
||||
thread_options = [f"{thread['thread_id']} - Created: {format_timestamp(thread['created_at'])}" for thread in sorted_threads]
|
||||
|
||||
if st.session_state.selected_thread is None and sorted_threads:
|
||||
st.session_state.selected_thread = sorted_threads[0]['thread_id']
|
||||
|
||||
selected_thread = st.selectbox(
|
||||
"🔍 Select Thread",
|
||||
thread_options,
|
||||
key="thread_select",
|
||||
index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0)
|
||||
)
|
||||
|
||||
if selected_thread:
|
||||
st.session_state.selected_thread = selected_thread.split(' - ')[0]
|
||||
else:
|
||||
st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}")
|
||||
|
||||
def display_run_history(thread_id):
|
||||
st.subheader("Run History")
|
||||
|
||||
# Fetch thread runs
|
||||
thread_runs = fetch_thread_runs(thread_id)
|
||||
|
||||
# Display thread runs
|
||||
st.write("### Thread Runs")
|
||||
for run in thread_runs:
|
||||
with st.expander(f"Run {run['id']} - Status: {run['status']}"):
|
||||
st.write(f"Created At: {format_timestamp(run['created_at'])}")
|
||||
st.write(f"Status: {run['status']}")
|
||||
st.write(f"Iterations: {len(run.get('iterations', []))} / {run.get('autonomous_iterations_amount', 1)}")
|
||||
|
||||
if run['status'] == "in_progress":
|
||||
if st.button(f"Stop Run {run['id']}", key=f"stop_thread_run_{run['id']}"):
|
||||
stop_thread_run(thread_id, run['id'])
|
||||
st.rerun()
|
||||
|
||||
if st.button(f"Refresh Status for Run {run['id']}", key=f"refresh_thread_run_{run['id']}"):
|
||||
updated_run = get_thread_run_status(thread_id, run['id'])
|
||||
if updated_run:
|
||||
run.update(updated_run)
|
||||
st.rerun()
|
||||
|
||||
def fetch_thread_runs(thread_id):
|
||||
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs")
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
st.error("Failed to fetch thread runs.")
|
||||
return []
|
||||
|
||||
def format_timestamp(timestamp):
|
||||
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
|
|
@ -1,195 +0,0 @@
|
|||
import streamlit as st
|
||||
import requests
|
||||
from agentpress.ui.utils import API_BASE_URL
|
||||
from datetime import datetime
|
||||
|
||||
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools):
|
||||
return {
|
||||
"system_message": {"role": "system", "content": system_message},
|
||||
"model_name": model_name,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"tools": selected_tools,
|
||||
"additional_system_message": additional_system_message,
|
||||
"tool_choice": "auto" # Add this line to ensure tool_choice is always set
|
||||
}
|
||||
|
||||
def prepare_run_thread_agent_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools, autonomous_iterations_amount, continue_instructions):
|
||||
return {
|
||||
"system_message": {"role": "system", "content": system_message},
|
||||
"model_name": model_name,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"tools": selected_tools,
|
||||
"additional_system_message": additional_system_message,
|
||||
"autonomous_iterations_amount": autonomous_iterations_amount,
|
||||
"continue_instructions": continue_instructions
|
||||
}
|
||||
|
||||
def run_thread(thread_id, run_thread_data):
|
||||
with st.spinner("Running thread..."):
|
||||
try:
|
||||
run_thread_response = requests.post(
|
||||
f"{API_BASE_URL}/threads/{thread_id}/run/",
|
||||
json=run_thread_data
|
||||
)
|
||||
run_thread_response.raise_for_status()
|
||||
response_data = run_thread_response.json()
|
||||
st.success(f"Thread run completed successfully! Status: {response_data.get('status', 'Unknown')}")
|
||||
|
||||
if 'id' in response_data:
|
||||
st.session_state.latest_run_id = response_data['id']
|
||||
|
||||
st.subheader("Response Content")
|
||||
display_response_content(response_data)
|
||||
|
||||
# Display the full response data
|
||||
st.subheader("Full Response Data")
|
||||
st.json(response_data)
|
||||
|
||||
return response_data # Return the response data
|
||||
except requests.exceptions.RequestException as e:
|
||||
st.error(f"Failed to run thread. Error: {str(e)}")
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
st.text("Response content:")
|
||||
st.text(e.response.text)
|
||||
except Exception as e:
|
||||
st.error(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
return None # Return None if there was an error
|
||||
|
||||
def run_thread_agent(thread_id, run_thread_agent_data):
|
||||
with st.spinner("Running thread agent..."):
|
||||
try:
|
||||
run_thread_response = requests.post(
|
||||
f"{API_BASE_URL}/threads/{thread_id}/run_agent/",
|
||||
json=run_thread_agent_data
|
||||
)
|
||||
run_thread_response.raise_for_status()
|
||||
response_data = run_thread_response.json()
|
||||
st.success(f"Thread agent run completed successfully! Status: {response_data['status']}")
|
||||
|
||||
st.subheader("Agent Response")
|
||||
display_agent_response_content(response_data)
|
||||
|
||||
# Display the full response data
|
||||
st.subheader("Full Agent Response Data")
|
||||
st.json(response_data)
|
||||
|
||||
return response_data
|
||||
except requests.exceptions.RequestException as e:
|
||||
st.error(f"Failed to run thread agent. Error: {str(e)}")
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
st.text("Response content:")
|
||||
st.text(e.response.text)
|
||||
except Exception as e:
|
||||
st.error(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
def display_response_content(response_data):
|
||||
if isinstance(response_data, dict) and 'choices' in response_data:
|
||||
message = response_data['choices'][0]['message']
|
||||
st.write(f"**Role:** {message['role']}")
|
||||
st.write(f"**Content:** {message['content']}")
|
||||
|
||||
if 'tool_calls' in message and message['tool_calls']:
|
||||
st.write("**Tool Calls:**")
|
||||
for tool_call in message['tool_calls']:
|
||||
st.write(f"- Function: `{tool_call['function']['name']}`")
|
||||
st.code(tool_call['function']['arguments'], language="json")
|
||||
else:
|
||||
st.json(response_data)
|
||||
|
||||
def display_agent_response_content(response_data):
|
||||
st.write(f"**Status:** {response_data['status']}")
|
||||
st.write(f"**Total Iterations:** {response_data['total_iterations']}")
|
||||
st.write(f"**Completed Iterations:** {response_data.get('iterations_count', 'N/A')}")
|
||||
|
||||
for i, iteration in enumerate(response_data['iterations']):
|
||||
with st.expander(f"Iteration {i+1}"):
|
||||
display_response_content(iteration)
|
||||
|
||||
st.write("**Final Configuration:**")
|
||||
st.json(response_data['final_config'])
|
||||
|
||||
def fetch_thread_runs(thread_id, limit):
|
||||
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs?limit={limit}")
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
st.error("Failed to retrieve runs.")
|
||||
return []
|
||||
|
||||
def format_timestamp(timestamp):
|
||||
if timestamp:
|
||||
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
|
||||
return 'N/A'
|
||||
|
||||
def display_runs(runs):
|
||||
for run in runs:
|
||||
with st.expander(f"Run {run['id']} - Status: {run['status']}", expanded=False):
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.write(f"**Created At:** {format_timestamp(run['created_at'])}")
|
||||
st.write(f"**Started At:** {format_timestamp(run['started_at'])}")
|
||||
st.write(f"**Completed At:** {format_timestamp(run['completed_at'])}")
|
||||
st.write(f"**Cancelled At:** {format_timestamp(run['cancelled_at'])}")
|
||||
st.write(f"**Failed At:** {format_timestamp(run['failed_at'])}")
|
||||
with col2:
|
||||
st.write(f"**Model:** {run['model']}")
|
||||
st.write(f"**Temperature:** {run['temperature']}")
|
||||
st.write(f"**Top P:** {run['top_p']}")
|
||||
st.write(f"**Max Tokens:** {run['max_tokens']}")
|
||||
st.write(f"**Tool Choice:** {run['tool_choice']}")
|
||||
st.write(f"**Execute Tools Async:** {run['execute_tools_async']}")
|
||||
st.write(f"**Autonomous Iterations:** {run['autonomous_iterations_amount']}")
|
||||
|
||||
st.write("**System Message:**")
|
||||
st.json(run['system_message'])
|
||||
|
||||
if run['tools']:
|
||||
st.write("**Tools:**")
|
||||
st.json(run['tools'])
|
||||
|
||||
if run['usage']:
|
||||
st.write("**Usage:**")
|
||||
st.json(run['usage'])
|
||||
|
||||
if run['response_format']:
|
||||
st.write("**Response Format:**")
|
||||
st.json(run['response_format'])
|
||||
|
||||
if run['last_error']:
|
||||
st.error("**Last Error:**")
|
||||
st.code(run['last_error'])
|
||||
|
||||
if run['continue_instructions']:
|
||||
st.write("**Continue Instructions:**")
|
||||
st.text(run['continue_instructions'])
|
||||
|
||||
if run['status'] == "in_progress":
|
||||
if st.button(f"Stop Run {run['id']}", key=f"stop_button_{run['id']}"):
|
||||
stop_thread_run(run['thread_id'], run['id'])
|
||||
st.rerun()
|
||||
|
||||
if st.button(f"Refresh Status for Run {run['id']}", key=f"refresh_button_{run['id']}"):
|
||||
updated_run = get_thread_run_status(run['thread_id'], run['id'])
|
||||
if updated_run:
|
||||
run.update(updated_run)
|
||||
st.rerun()
|
||||
|
||||
def stop_thread_run(thread_id, run_id):
|
||||
response = requests.post(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/stop")
|
||||
if response.status_code == 200:
|
||||
st.success("Thread run stopped successfully.")
|
||||
return response.json()
|
||||
else:
|
||||
st.error(f"Failed to stop thread run. Status code: {response.status_code}")
|
||||
return None
|
||||
|
||||
def get_thread_run_status(thread_id, run_id):
|
||||
response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/status")
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
st.error(f"Failed to get thread run status. Status code: {response.status_code}")
|
||||
return None
|
|
@ -1,43 +0,0 @@
|
|||
import streamlit as st
|
||||
import requests
|
||||
from agentpress.ui.utils import API_BASE_URL
|
||||
|
||||
def display_tools():
|
||||
st.header("Available Tools")
|
||||
|
||||
tools = fetch_tools()
|
||||
|
||||
if not tools:
|
||||
st.warning("No tools available. Please check the API connection.")
|
||||
return
|
||||
|
||||
view_mode = st.radio("View Mode", ["Simple", "Detailed"])
|
||||
|
||||
if view_mode == "Simple":
|
||||
display_simple_view(tools)
|
||||
else:
|
||||
display_detailed_view(tools)
|
||||
|
||||
def display_simple_view(tools):
|
||||
for tool_name, tool_info in tools.items():
|
||||
with st.expander(f"🛠️ {tool_name}"):
|
||||
st.write(f"**Description:** {tool_info['description']}")
|
||||
|
||||
def display_detailed_view(tools):
|
||||
for tool_name, tool_info in tools.items():
|
||||
with st.expander(f"🛠️ {tool_name}"):
|
||||
st.write(f"**Description:** {tool_info['description']}")
|
||||
if tool_info['schema']:
|
||||
st.write("**Schema:**")
|
||||
st.json(tool_info['schema'])
|
||||
|
||||
def fetch_tools():
|
||||
response = requests.get(f"{API_BASE_URL}/tools/")
|
||||
if response.status_code == 200:
|
||||
tools = response.json()
|
||||
st.session_state.tools = tools
|
||||
return tools
|
||||
else:
|
||||
st.error(f"Failed to fetch tools. Status code: {response.status_code}")
|
||||
st.error(f"Error message: {response.text}")
|
||||
return {}
|
|
@ -1,23 +0,0 @@
|
|||
import streamlit as st
|
||||
import requests
|
||||
from agentpress.constants import AI_MODELS, STANDARD_SYSTEM_MESSAGE
|
||||
|
||||
API_BASE_URL = "http://localhost:8000"
|
||||
|
||||
def fetch_tools():
|
||||
response = requests.get(f"{API_BASE_URL}/tools/")
|
||||
if response.status_code == 200:
|
||||
st.session_state.tools = response.json()
|
||||
else:
|
||||
st.error("Failed to fetch tools.")
|
||||
|
||||
def initialize_session_state():
|
||||
if 'selected_thread' not in st.session_state:
|
||||
st.session_state.selected_thread = None
|
||||
if 'tools' not in st.session_state:
|
||||
st.session_state.tools = []
|
||||
if 'fetch_tools' not in st.session_state:
|
||||
st.session_state.fetch_tools = fetch_tools
|
||||
|
||||
def fetch_data():
|
||||
fetch_tools()
|
Loading…
Reference in New Issue