From 3f69ea9cc446b898eae55c5e5914c9d8bfd8ace5 Mon Sep 17 00:00:00 2001 From: marko-kraemer Date: Wed, 23 Oct 2024 03:28:12 +0200 Subject: [PATCH] v1 --- agentpress/__init__.py | 4 +- agentpress/api.py | 182 -------- agentpress/db.py | 33 -- agentpress/llm.py | 13 +- agentpress/thread_manager.py | 681 +++++------------------------ agentpress/tool.py | 134 +++++- agentpress/tool_registry.py | 57 ++- agentpress/ui/__init__.py | 4 - agentpress/ui/main.py | 44 -- agentpress/ui/message_display.py | 106 ----- agentpress/ui/real_time_updates.py | 22 - agentpress/ui/thread_management.py | 87 ---- agentpress/ui/thread_runner.py | 195 --------- agentpress/ui/tool_display.py | 43 -- agentpress/ui/utils.py | 23 - 15 files changed, 273 insertions(+), 1355 deletions(-) delete mode 100644 agentpress/api.py delete mode 100644 agentpress/ui/__init__.py delete mode 100644 agentpress/ui/main.py delete mode 100644 agentpress/ui/message_display.py delete mode 100644 agentpress/ui/real_time_updates.py delete mode 100644 agentpress/ui/thread_management.py delete mode 100644 agentpress/ui/thread_runner.py delete mode 100644 agentpress/ui/tool_display.py delete mode 100644 agentpress/ui/utils.py diff --git a/agentpress/__init__.py b/agentpress/__init__.py index 437ac0c6..349e1a2f 100644 --- a/agentpress/__init__.py +++ b/agentpress/__init__.py @@ -1,10 +1,10 @@ from .config import settings -from .db import Database, Thread, ThreadRun +from .db import Database, Thread from .llm import make_llm_api_call from .thread_manager import ThreadManager # from .working_memory_manager import WorkingMemory __all__ = [ - 'settings', 'Database', 'Thread', 'ThreadRun', + 'settings', 'Database', 'Thread', 'make_llm_api_call', 'ThreadManager' ] #'WorkingMemory' \ No newline at end of file diff --git a/agentpress/api.py b/agentpress/api.py deleted file mode 100644 index 8275d919..00000000 --- a/agentpress/api.py +++ /dev/null @@ -1,182 +0,0 @@ -from fastapi import FastAPI, HTTPException, Query, Path -from pydantic import BaseModel, Field -from typing import List, Dict, Any, Optional -import asyncio -import json - -from agentpress.db import Database -from agentpress.thread_manager import ThreadManager -from agentpress.tool_registry import ToolRegistry -from agentpress.config import Settings - -app = FastAPI( - title="Thread Manager API", - description="API for managing and running threads with LLM integration", - version="1.0.0", -) - -db = Database() -manager = ThreadManager(db) -tool_registry = ToolRegistry() - -class Message(BaseModel): - role: str = Field(..., description="The role of the message sender (e.g., 'user', 'assistant')") - content: str = Field(..., description="The content of the message") - -class RunThreadRequest(BaseModel): - system_message: Dict[str, Any] = Field(..., description="The system message to be used for the thread run") - model_name: str = Field(..., description="The name of the LLM model to be used") - temperature: float = Field(0.5, description="The sampling temperature for the LLM") - max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate") - tools: Optional[List[str]] = Field(None, description="The list of tools to be used in the thread run") - tool_choice: str = Field("auto", description="Controls which tool is called by the model") - additional_system_message: Optional[str] = Field(None, description="Additional system message to be appended to the existing system message") - additional_message: Optional[Dict[str, Any]] = Field(None, description="Additional 
message to be appended at the end of the conversation") - hide_tool_msgs: bool = Field(False, description="Whether to hide tool messages in the conversation history") - execute_tools_async: bool = Field(True, description="Whether to execute tools asynchronously") - use_tool_parser: bool = Field(False, description="Whether to use the tool parser for handling tool calls") - top_p: Optional[float] = Field(None, description="The nucleus sampling value") - response_format: Optional[Dict[str, Any]] = Field(None, description="Specifies the format that the model must output") - autonomous_iterations_amount: Optional[int] = Field(None, description="The number of autonomous iterations to perform") - continue_instructions: Optional[str] = Field(None, description="Instructions for continuing the conversation in subsequent iterations") - initializer: Optional[str] = Field(None, description="Name of the initializer function") - pre_iteration: Optional[str] = Field(None, description="Name of the pre-iteration function") - after_iteration: Optional[str] = Field(None, description="Name of the after-iteration function") - finalizer: Optional[str] = Field(None, description="Name of the finalizer function") - -@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread") -async def create_thread(): - """ - Create a new thread and return its ID. - """ - thread_id = await manager.create_thread() - return {"thread_id": thread_id} - -@app.get("/threads/", response_model=List[Dict[str, Any]], summary="Get all threads") -async def get_threads(): - """ - Retrieve a list of all threads. - """ - threads = await manager.get_threads() - return [{"thread_id": thread.thread_id, "created_at": thread.created_at} for thread in threads] - -@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread") -async def add_message(thread_id: str, message: Message): - """ - Add a new message to the specified thread. - """ - await manager.add_message(thread_id, message.dict()) - return {"status": "success"} - -@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread") -async def list_messages(thread_id: str): - """ - Retrieve all messages from the specified thread. 
- """ - messages = await manager.list_messages(thread_id) - return messages - -@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread") -async def run_thread(thread_id: str, request: RunThreadRequest): - try: - # Create a new ThreadRun object - thread_run = await manager.create_thread_run( - thread_id, - model_name=request.model_name, - temperature=request.temperature, - max_tokens=request.max_tokens, - top_p=request.top_p, - tool_choice=request.tool_choice, - execute_tools_async=request.execute_tools_async, - system_message=json.dumps(request.system_message), - tools=json.dumps(request.tools), - response_format=json.dumps(request.response_format), - autonomous_iterations_amount=request.autonomous_iterations_amount, - continue_instructions=request.continue_instructions - ) - - # Run the thread with the created ThreadRun object - result = await manager.run_thread( - thread_id=thread_id, - thread_run=thread_run, - initializer=get_function(request.initializer), - pre_iteration=get_function(request.pre_iteration), - after_iteration=get_function(request.after_iteration), - finalizer=get_function(request.finalizer), - **request.dict(exclude={'initializer', 'pre_iteration', 'after_iteration', 'finalizer'}) - ) - return result - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -def get_function(function_name: Optional[str]): - if function_name is None: - return None - # Implement a way to get the function by name, e.g., from a predefined dictionary of functions - # For now, we'll return None - return None - -@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools") -async def get_tools(): - """ - Retrieve a list of all available tools and their schemas. - """ - tools = tool_registry.get_all_tools() - if not tools: - print("No tools found in the registry") # Debug print - return { - name: { - "name": name, - "description": tool_info['schema']['function']['description'], - "schema": tool_info['schema'] - } - for name, tool_info in tools.items() - } - -@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run") -async def get_run( - thread_id: str = Path(..., description="The ID of the thread that was run"), - run_id: str = Path(..., description="The ID of the run to retrieve") -): - run = await manager.get_run(thread_id, run_id) - if run is None: - raise HTTPException(status_code=404, detail="Run not found") - return run - -@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run") -async def cancel_run( - thread_id: str = Path(..., description="The ID of the thread to which this run belongs"), - run_id: str = Path(..., description="The ID of the run to cancel") -): - """ - Cancels a run that is in_progress. 
- """ - run = await manager.cancel_run(thread_id, run_id) - if run is None: - raise HTTPException(status_code=404, detail="Run not found") - return run - -@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs") -async def list_runs( - thread_id: str = Path(..., description="The ID of the thread the runs belong to"), - limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned") -): - runs = await manager.list_runs(thread_id, limit) - return runs - -@app.post("/threads/{thread_id}/runs/{run_id}/stop", response_model=Dict[str, Any], summary="Stop a thread run") -async def stop_thread_run( - thread_id: str = Path(..., description="The ID of the thread"), - run_id: str = Path(..., description="The ID of the run to stop") -): - """ - Stops a thread run that is in progress. - """ - run = await manager.stop_thread_run(thread_id, run_id) - if run is None: - raise HTTPException(status_code=404, detail="Run not found or already completed/stopped") - return run - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/agentpress/db.py b/agentpress/db.py index 2e246058..86e5a6ac 100644 --- a/agentpress/db.py +++ b/agentpress/db.py @@ -17,43 +17,10 @@ class Thread(Base): messages = Column(Text) created_at = Column(Integer) - runs = relationship("ThreadRun", back_populates="thread") - def __init__(self, **kwargs): super().__init__(**kwargs) self.created_at = int(datetime.utcnow().timestamp()) -class ThreadRun(Base): - __tablename__ = "thread_runs" - - id = Column(String, primary_key=True) - thread_id = Column(String, ForeignKey("threads.thread_id")) - status = Column(String, default="queued") - last_error = Column(String, nullable=True) - created_at = Column(Integer) - started_at = Column(Integer, nullable=True) - cancelled_at = Column(Integer, nullable=True) - failed_at = Column(Integer, nullable=True) - completed_at = Column(Integer, nullable=True) - model = Column(String) - temperature = Column(Float) - max_tokens = Column(Integer, nullable=True) - top_p = Column(Float, nullable=True) - tool_choice = Column(String) - execute_tools_async = Column(Boolean) - system_message = Column(JSON) - tools = Column(JSON, nullable=True) - usage = Column(JSON, nullable=True) - response_format = Column(JSON, nullable=True) - autonomous_iterations_amount = Column(Integer, nullable=True) - continue_instructions = Column(String, nullable=True) - iterations = Column(JSON, nullable=True) - - thread = relationship("Thread", back_populates="runs") - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.created_at = int(datetime.utcnow().timestamp()) # class MemoryModule(Base): # __tablename__ = 'memory_modules' diff --git a/agentpress/llm.py b/agentpress/llm.py index 25ac4e73..fa6b6cad 100644 --- a/agentpress/llm.py +++ b/agentpress/llm.py @@ -26,7 +26,7 @@ os.environ['GROQ_API_KEY'] = GROQ_API_KEY logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]: +async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> 
Union[Dict[str, Any], str]: litellm.set_verbose = True async def attempt_api_call(api_call_func, max_attempts=3): @@ -69,14 +69,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0 api_call_params["max_tokens"] = max_tokens if tools: - if use_tool_parser: - # Add tools as user messages - tools_message = {"role": "user", "content": json.dumps(tools)} - api_call_params["messages"].append(tools_message) - else: - # Use the existing method of adding tools - api_call_params["tools"] = tools - api_call_params["tool_choice"] = tool_choice + # Use the existing method of adding tools + api_call_params["tools"] = tools + api_call_params["tool_choice"] = tool_choice if "claude" in model_name.lower() or "anthropic" in model_name.lower(): api_call_params["extra_headers"] = { diff --git a/agentpress/thread_manager.py b/agentpress/thread_manager.py index c35e23e3..1cdf3660 100644 --- a/agentpress/thread_manager.py +++ b/agentpress/thread_manager.py @@ -1,22 +1,27 @@ import json import logging import asyncio -from typing import List, Dict, Any, Optional, Callable +from typing import List, Dict, Any, Optional, Callable, Type from sqlalchemy import select -from agentpress.db import Database, Thread, ThreadRun +from agentpress.db import Database, Thread from agentpress.llm import make_llm_api_call from datetime import datetime, UTC -from agentpress.tool import ToolResult +from agentpress.tool import Tool, ToolResult from agentpress.tool_registry import ToolRegistry import uuid -from tools.files_tool import FilesTool class ThreadManager: - def __init__(self, db: Database): - self.db = db + def __init__(self, db: Optional[Database] = None): + self.db = db if db is not None else Database() self.tool_registry = ToolRegistry() - self.run_config: Dict[str, Any] = {} - self.current_iteration: int = 0 + + def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None): + """ + Add a tool to the ThreadManager. + If function_names is provided, only register those specific functions. + If function_names is None, register all functions from the tool. 
+ """ + self.tool_registry.register_tool(tool_class, function_names) async def create_thread(self) -> int: async with self.db.get_async_session() as session: @@ -25,7 +30,7 @@ class ThreadManager: ) session.add(new_thread) await session.commit() - await session.refresh(new_thread) # Ensure thread_id is populated + await session.refresh(new_thread) return new_thread.thread_id async def add_message(self, thread_id: int, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None): @@ -38,9 +43,7 @@ class ThreadManager: try: messages = json.loads(thread.messages) - # If we're adding a user message, perform checks if message_data['role'] == 'user': - # Find the last assistant message with tool calls last_assistant_index = next((i for i in reversed(range(len(messages))) if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None) if last_assistant_index is not None: @@ -50,12 +53,10 @@ class ThreadManager: if tool_call_count != tool_response_count: await self.cleanup_incomplete_tool_calls(thread_id) - # Convert ToolResult objects to strings for key, value in message_data.items(): if isinstance(value, ToolResult): message_data[key] = str(value) - # Process images if present if images: if isinstance(message_data['content'], str): message_data['content'] = [{"type": "text", "text": message_data['content']}] @@ -81,50 +82,6 @@ class ThreadManager: logging.error(f"Failed to add message to thread {thread_id}: {e}") raise e - async def get_message(self, thread_id: int, message_index: int) -> Optional[Dict[str, Any]]: - async with self.db.get_async_session() as session: - thread = await session.get(Thread, thread_id) - if not thread: - return None - messages = json.loads(thread.messages) - if message_index < len(messages): - return messages[message_index] - return None - - async def modify_message(self, thread_id: int, message_index: int, new_message_data: Dict[str, Any]): - async with self.db.get_async_session() as session: - thread = await session.get(Thread, thread_id) - if not thread: - raise ValueError(f"Thread with id {thread_id} not found") - - try: - messages = json.loads(thread.messages) - if message_index < len(messages): - messages[message_index] = new_message_data - thread.messages = json.dumps(messages) - await session.commit() - else: - raise ValueError(f"Message index {message_index} is out of range") - except Exception as e: - await session.rollback() - raise e - - async def remove_message(self, thread_id: int, message_index: int): - async with self.db.get_async_session() as session: - thread = await session.get(Thread, thread_id) - if not thread: - raise ValueError(f"Thread with id {thread_id} not found") - - try: - messages = json.loads(thread.messages) - if message_index < len(messages): - del messages[message_index] - thread.messages = json.dumps(messages) - await session.commit() - except Exception as e: - await session.rollback() - raise e - async def list_messages(self, thread_id: int, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]: async with self.db.get_async_session() as session: thread = await session.get(Thread, thread_id) @@ -138,7 +95,7 @@ class ThreadManager: return [msg] return [] - filtered_messages = messages # Initialize filtered_messages with all messages + filtered_messages = messages if hide_tool_msgs: filtered_messages = [ @@ -164,7 +121,6 @@ class ThreadManager: tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:] if m['role'] == 
'tool'] if len(tool_calls) != len(tool_responses): - # Create failed ToolResults for incomplete tool calls failed_tool_results = [] for tool_call in tool_calls[len(tool_responses):]: failed_tool_result = { @@ -175,7 +131,6 @@ class ThreadManager: } failed_tool_results.append(failed_tool_result) - # Insert failed tool results after the last assistant message assistant_index = messages.index(last_assistant_message) messages[assistant_index+1:assistant_index+1] = failed_tool_results @@ -188,236 +143,99 @@ class ThreadManager: return True return False - async def run_thread(self, settings: Dict[str, Any]) -> Dict[str, Any]: - try: - thread_run = ThreadRun( - id=str(uuid.uuid4()), - thread_id=settings['thread_id'], - status="queued", - model=settings['model_name'], - temperature=settings.get('temperature', 0.7), - max_tokens=settings.get('max_tokens'), - top_p=settings.get('top_p'), - tool_choice=settings.get('tool_choice', 'auto'), - execute_tools_async=settings.get('execute_tools_async', True), - system_message=json.dumps(settings['system_message']), - tools=json.dumps(settings.get('tools')), - response_format=json.dumps(settings.get('response_format')), - autonomous_iterations_amount=settings.get('autonomous_iterations_amount', 1), - continue_instructions=settings.get('continue_instructions') - ) - - async with self.db.get_async_session() as session: - session.add(thread_run) - await session.commit() - - thread_run.status = "in_progress" - thread_run.started_at = int(datetime.now(UTC).timestamp()) - await self.update_thread_run(thread_run) - - self.run_config = {k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')} - self.run_config['iterations'] = [] - - if settings.get('initializer'): - settings['initializer']() - # Update thread_run with changes from run_config - for key, value in self.run_config.items(): - setattr(thread_run, key, value) - await self.update_thread_run(thread_run) - - full_tools = None - if settings.get('tools'): - full_tools = [self.tool_registry.get_tool(tool_name)['schema'] for tool_name in settings['tools'] if self.tool_registry.get_tool(tool_name)] - - self.current_iteration = 0 - for iteration in range(settings.get('autonomous_iterations_amount', 1)): - self.current_iteration = iteration + 1 - - if await self.should_stop(settings['thread_id'], thread_run.id): - thread_run.status = "stopped" - thread_run.cancelled_at = int(datetime.now(UTC).timestamp()) - await self.update_thread_run(thread_run) - return {"status": "stopped", "message": "Thread run cancelled"} - - if settings.get('pre_iteration'): - settings['pre_iteration']() - # Update thread_run with changes from run_config - for key, value in self.run_config.items(): - setattr(thread_run, key, value) - await self.update_thread_run(thread_run) - - if iteration > 0 and settings.get('continue_instructions'): - await self.add_message(settings['thread_id'], {"role": "user", "content": settings['continue_instructions']}) - - messages = await self.list_messages(settings['thread_id'], hide_tool_msgs=settings.get('hide_tool_msgs', False)) - prepared_messages = [settings['system_message']] + messages - - if settings.get('additional_message'): - prepared_messages.append(settings['additional_message']) - - response = await make_llm_api_call( - prepared_messages, - settings['model_name'], - temperature=thread_run.temperature, - max_tokens=thread_run.max_tokens, - tools=full_tools, - tool_choice=thread_run.tool_choice, - stream=False, - top_p=thread_run.top_p, - 
response_format=json.loads(thread_run.response_format) if thread_run.response_format else None - ) - - usage = response.usage if hasattr(response, 'usage') else None - usage_dict = self.serialize_usage(usage) if usage else None - thread_run.usage = usage_dict - - assistant_message = { - "role": "assistant", - "content": response.choices[0].message['content'] - } - if 'tool_calls' in response.choices[0].message: - assistant_message['tool_calls'] = response.choices[0].message['tool_calls'] - - await self.add_message(settings['thread_id'], assistant_message) - - if settings.get('tools') is None or settings.get('use_tool_parser', False): - await self.handle_response_without_tools(settings['thread_id'], response, settings.get('use_tool_parser', False)) - else: - await self.handle_response_with_tools(settings['thread_id'], response, settings.get('execute_tools_async', True)) - - self.run_config['iterations'].append({ - "iteration": self.current_iteration, - "response": self.serialize_choice(response.choices[0]), - "usage": usage_dict - }) - - if settings.get('after_iteration'): - settings['after_iteration']() - # Update thread_run with changes from run_config - for key, value in self.run_config.items(): - setattr(thread_run, key, value) - - thread_run.iterations = json.dumps(self.run_config['iterations']) - await self.update_thread_run(thread_run) - - thread_run.status = "completed" - thread_run.completed_at = int(datetime.now(UTC).timestamp()) - await self.update_thread_run(thread_run) - - self.run_config.update({k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')}) - - if settings.get('finalizer'): - settings['finalizer']() - # Update thread_run with final changes from run_config - for key, value in self.run_config.items(): - setattr(thread_run, key, value) - await self.update_thread_run(thread_run) - - return { - "id": thread_run.id, - "status": thread_run.status, - "iterations": self.run_config['iterations'], - "total_iterations": len(self.run_config['iterations']), - "usage": thread_run.usage, - "model": settings['model_name'], - "object": "chat.completion", - "created": int(datetime.now(UTC).timestamp()) - } - except Exception as e: - thread_run.status = "failed" - thread_run.failed_at = int(datetime.now(UTC).timestamp()) - thread_run.last_error = str(e) - await self.update_thread_run(thread_run) - self.run_config.update({k: v for k, v in thread_run.__dict__.items() if not k.startswith('_')}) - if settings.get('finalizer'): - settings['finalizer']() - raise - - async def update_thread_run(self, thread_run: ThreadRun): - async with self.db.get_async_session() as session: - session.add(thread_run) - await session.commit() - await session.refresh(thread_run) - - async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool): - response_content = response.choices[0].message['content'] + async def run_thread(self, thread_id: int, system_message: Dict[str, Any], model_name: str, temperature: float = 0, max_tokens: Optional[int] = None, tool_choice: str = "auto", additional_message: Optional[Dict[str, Any]] = None, execute_tools_async: bool = True, execute_model_tool_calls: bool = True, use_tools: bool = True) -> Dict[str, Any]: - if use_tool_parser: - await self.handle_tool_parser_response(thread_id, response_content) - else: - # The message has already been added in the run_thread method, so we don't need to add it again here - pass + messages = await self.list_messages(thread_id) + prepared_messages = [system_message] + messages + + if 
additional_message: + prepared_messages.append(additional_message) + + tools = self.tool_registry.get_all_tool_schemas() if use_tools else None - async def handle_tool_parser_response(self, thread_id: int, response_content: str): - tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content) - if tool_call_match: - try: - tool_call_json = json.loads(tool_call_match.group()) - tool_calls = tool_call_json.get('function_calls', []) - - assistant_message = { - "role": "assistant", - "content": response_content, - "tool_calls": [ - { - "id": f"call_{i}", - "type": "function", - "function": { - "name": call['name'], - "arguments": json.dumps(call['arguments']) - } - } for i, call in enumerate(tool_calls) - ] + try: + llm_response = await make_llm_api_call( + prepared_messages, + model_name, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + tool_choice=tool_choice if use_tools else None, + stream=False + ) + except Exception as e: + logging.error(f"Error in API call: {str(e)}") + return { + "status": "error", + "message": str(e), + "run_thread_params": { + "thread_id": thread_id, + "system_message": system_message, + "model_name": model_name, + "temperature": temperature, + "max_tokens": max_tokens, + "tool_choice": tool_choice, + "additional_message": additional_message, + "execute_tools_async": execute_tools_async, + "execute_model_tool_calls": execute_model_tool_calls, + "use_tools": use_tools } - await self.add_message(thread_id, assistant_message) + } - available_functions = self.get_available_functions() - - tool_results = await self.execute_tools(assistant_message['tool_calls'], available_functions, thread_id, execute_tools_async=True) - - await self.process_tool_results(thread_id, tool_results) - - except json.JSONDecodeError: - logging.error("Failed to parse tool call JSON from response") - await self.add_message(thread_id, {"role": "assistant", "content": response_content}) + if use_tools and execute_model_tool_calls: + await self.handle_response_with_tools(thread_id, llm_response, execute_tools_async) else: - await self.add_message(thread_id, {"role": "assistant", "content": response_content}) + await self.handle_response_without_tools(thread_id, llm_response) + + return { + "llm_response": llm_response, + "run_thread_params": { + "thread_id": thread_id, + "system_message": system_message, + "model_name": model_name, + "temperature": temperature, + "max_tokens": max_tokens, + "tool_choice": tool_choice, + "additional_message": additional_message, + "execute_tools_async": execute_tools_async, + "execute_model_tool_calls": execute_model_tool_calls, + "use_tools": use_tools + } + } + + async def handle_response_without_tools(self, thread_id: int, response: Any): + response_content = response.choices[0].message['content'] + await self.add_message(thread_id, {"role": "assistant", "content": response_content}) async def handle_response_with_tools(self, thread_id: int, response: Any, execute_tools_async: bool): try: response_message = response.choices[0].message tool_calls = response_message.get('tool_calls', []) - # The assistant message has already been added in the run_thread method - + assistant_message = self.create_assistant_message_with_tools(response_message) + await self.add_message(thread_id, assistant_message) + available_functions = self.get_available_functions() - if await self.should_stop(thread_id, thread_id): - return {"status": "stopped", "message": "Session cancelled"} - if tool_calls: if execute_tools_async: tool_results = await 
self.execute_tools_async(tool_calls, available_functions, thread_id) else: tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id) - # Add tool results to messages for result in tool_results: await self.add_message(thread_id, result) - - if await self.should_stop(thread_id, thread_id): - return {"status": "stopped", "message": "Session cancelled after tool execution"} except AttributeError as e: logging.error(f"AttributeError: {e}") - # No need to add the message here as it's already been added in the run_thread method + response_content = response.choices[0].message['content'] + await self.add_message(thread_id, {"role": "assistant", "content": response_content or ""}) def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]: message = { "role": "assistant", "content": response_message.get('content') or "", } - tool_calls = response_message.get('tool_calls') if tool_calls: message["tool_calls"] = [ @@ -441,33 +259,12 @@ class ThreadManager: available_functions[func_name] = getattr(tool_instance, func_name) return available_functions - async def execute_tools(self, tool_calls: List[Any], available_functions: Dict[str, Callable], thread_id: int, execute_tools_async: bool) -> List[Dict[str, Any]]: - if execute_tools_async: - return await self.execute_tools_async(tool_calls, available_functions, thread_id) - else: - return await self.execute_tools_sync(tool_calls, available_functions, thread_id) - async def execute_tools_async(self, tool_calls, available_functions, thread_id): async def execute_single_tool(tool_call): - if await self.should_stop(thread_id, thread_id): - return {"status": "stopped", "message": "Session cancelled"} - function_name = tool_call.function.name + function_args = json.loads(tool_call.function.arguments) tool_call_id = tool_call.id - try: - function_args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError as e: - error_message = f"Error parsing arguments for {function_name}: {str(e)}" - logging.error(error_message) - logging.error(f"Problematic JSON: {tool_call.function.arguments}") - return { - "role": "tool", - "tool_call_id": tool_call_id, - "name": function_name, - "content": str(ToolResult(success=False, output=error_message)), - } - function_to_call = available_functions.get(function_name) if function_to_call: return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id) @@ -481,9 +278,6 @@ class ThreadManager: async def execute_tools_sync(self, tool_calls, available_functions, thread_id): tool_results = [] for tool_call in tool_calls: - if await self.should_stop(thread_id, thread_id): - return [{"status": "stopped", "message": "Session cancelled"}] - function_name = tool_call.function.name function_args = json.loads(tool_call.function.arguments) tool_call_id = tool_call.id @@ -498,10 +292,6 @@ class ThreadManager: return tool_results - async def process_tool_results(self, thread_id: int, tool_results: List[Dict[str, Any]]): - for result in tool_results: - await self.add_message(thread_id, result['tool_message']) - async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id): try: function_response = await function_to_call(**function_args) @@ -516,297 +306,60 @@ class ThreadManager: "content": str(function_response), } - async def should_stop(self, thread_id: str, run_id: str) -> bool: - async with self.db.get_async_session() as session: - run = await session.get(ThreadRun, run_id) - if run and run.status in ["stopped", 
"cancelled", "queued"]: - return True - return False - - async def stop_thread_run(self, thread_id: str, run_id: str) -> Dict[str, Any]: - async with self.db.get_async_session() as session: - run = await session.get(ThreadRun, run_id) - if run and run.thread_id == thread_id and run.status == "in_progress": - run.status = "stopping" - await session.commit() - return self.serialize_thread_run(run) - return None - - async def save_thread_run(self, thread_id: str): - async with self.db.get_async_session() as session: - thread = await session.get(Thread, thread_id) - if not thread: - raise ValueError(f"Thread with id {thread_id} not found") - - messages = json.loads(thread.messages) - creation_date = datetime.now().isoformat() - - # Get the latest ThreadRun for this thread - stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1) - result = await session.execute(stmt) - latest_thread_run = result.scalar_one_or_none() - - if latest_thread_run: - # Update the existing ThreadRun - latest_thread_run.messages = json.dumps(messages) - latest_thread_run.last_updated_date = creation_date - await session.commit() - else: - # Create a new ThreadRun if none exists - new_thread_run = ThreadRun( - thread_id=thread_id, - messages=json.dumps(messages), - creation_date=creation_date, - status='completed' - ) - session.add(new_thread_run) - await session.commit() - async def get_thread(self, thread_id: int) -> Optional[Thread]: async with self.db.get_async_session() as session: return await session.get(Thread, thread_id) - async def update_thread_run_with_error(self, thread_id: int, error_message: str): - async with self.db.get_async_session() as session: - stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.run_id.desc()).limit(1) - result = await session.execute(stmt) - thread_run = result.scalar_one_or_none() - if thread_run: - thread_run.status = 'error' - thread_run.error_message = error_message # Store the full error message - await session.commit() - - async def get_threads(self) -> List[Thread]: - async with self.db.get_async_session() as session: - result = await session.execute(select(Thread).order_by(Thread.thread_id.desc())) - return result.scalars().all() - - async def get_latest_thread_run(self, thread_id: str): - async with self.db.get_async_session() as session: - stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1) - result = await session.execute(stmt) - latest_run = result.scalar_one_or_none() - if latest_run: - return { - "id": latest_run.id, - "status": latest_run.status, - "error_message": latest_run.last_error, - "created_at": latest_run.created_at, - "started_at": latest_run.started_at, - "completed_at": latest_run.completed_at, - "cancelled_at": latest_run.cancelled_at, - "failed_at": latest_run.failed_at, - "model": latest_run.model, - "usage": latest_run.usage - } - return None - - async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]: - async with self.db.get_async_session() as session: - run = await session.get(ThreadRun, run_id) - if run and run.thread_id == thread_id: - return { - "id": run.id, - "thread_id": run.thread_id, - "status": run.status, - "created_at": run.created_at, - "started_at": run.started_at, - "completed_at": run.completed_at, - "cancelled_at": run.cancelled_at, - "failed_at": run.failed_at, - "model": run.model, - "system_message": json.loads(run.system_message) if run.system_message else 
None, - "tools": json.loads(run.tools) if run.tools else None, - "usage": run.usage, - "temperature": run.temperature, - "top_p": run.top_p, - "max_tokens": run.max_tokens, - "tool_choice": run.tool_choice, - "execute_tools_async": run.execute_tools_async, - "response_format": json.loads(run.response_format) if run.response_format else None, - "last_error": run.last_error - } - return None - - async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]: - async with self.db.get_async_session() as session: - run = await session.get(ThreadRun, run_id) - if run and run.thread_id == thread_id and run.status == "in_progress": - run.status = "cancelled" - run.cancelled_at = int(datetime.now(UTC).timestamp()) - await session.commit() - return await self.get_run(thread_id, run_id) - return None - - async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]: - async with self.db.get_async_session() as session: - thread_runs_stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit) - thread_runs_result = await session.execute(thread_runs_stmt) - thread_runs = thread_runs_result.scalars().all() - return [self.serialize_thread_run(run) for run in thread_runs] - - async def create_thread_run(self, thread_id: str, **kwargs) -> ThreadRun: - run_id = str(uuid.uuid4()) - thread_run = ThreadRun( - id=run_id, - thread_id=thread_id, - status="queued", - model=kwargs.get('model_name'), - temperature=kwargs.get('temperature'), - max_tokens=kwargs.get('max_tokens'), - top_p=kwargs.get('top_p'), - tool_choice=kwargs.get('tool_choice', "auto"), - execute_tools_async=kwargs.get('execute_tools_async', True), - system_message=json.dumps(kwargs.get('system_message')), - tools=json.dumps(kwargs.get('tools')), - response_format=json.dumps(kwargs.get('response_format')), - autonomous_iterations_amount=kwargs.get('autonomous_iterations_amount'), - continue_instructions=kwargs.get('continue_instructions') - ) - async with self.db.get_async_session() as session: - session.add(thread_run) - await session.commit() - return thread_run - - async def get_thread_run_count(self, thread_id: str) -> int: - async with self.db.get_async_session() as session: - result = await session.execute(select(ThreadRun).filter_by(thread_id=thread_id)) - return len(result.all()) - - async def get_thread_run_status(self, thread_id: str, run_id: str) -> Dict[str, Any]: - async with self.db.get_async_session() as session: - run = await session.get(ThreadRun, run_id) - if run and run.thread_id == thread_id: - return self.serialize_thread_run(run) - return None - - def serialize_thread_run(self, run: ThreadRun) -> Dict[str, Any]: - return { - "id": run.id, - "thread_id": run.thread_id, - "status": run.status, - "created_at": run.created_at, - "started_at": run.started_at, - "completed_at": run.completed_at, - "cancelled_at": run.cancelled_at, - "failed_at": run.failed_at, - "model": run.model, - "temperature": run.temperature, - "max_tokens": run.max_tokens, - "top_p": run.top_p, - "tool_choice": run.tool_choice, - "execute_tools_async": run.execute_tools_async, - "system_message": json.loads(run.system_message) if run.system_message else None, - "tools": json.loads(run.tools) if run.tools else None, - "usage": run.usage, - "response_format": json.loads(run.response_format) if run.response_format else None, - "last_error": run.last_error, - "autonomous_iterations_amount": run.autonomous_iterations_amount, - "continue_instructions": 
run.continue_instructions, - "iterations": json.loads(run.iterations) if run.iterations else None - } - - def serialize_usage(self, usage): - return { - "completion_tokens": usage.completion_tokens, - "prompt_tokens": usage.prompt_tokens, - "total_tokens": usage.total_tokens, - "completion_tokens_details": self.serialize_completion_tokens_details(usage.completion_tokens_details), - "prompt_tokens_details": self.serialize_prompt_tokens_details(usage.prompt_tokens_details) - } - - def serialize_completion_tokens_details(self, details): - return { - "audio_tokens": details.audio_tokens, - "reasoning_tokens": details.reasoning_tokens - } - - def serialize_prompt_tokens_details(self, details): - return { - "audio_tokens": details.audio_tokens, - "cached_tokens": details.cached_tokens - } - - def serialize_choice(self, choice): - return { - "finish_reason": choice.finish_reason, - "index": choice.index, - "message": self.serialize_message(choice.message) - } - - def serialize_message(self, message): - return { - "content": message.content, - "role": message.role, - "tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None - } - - def serialize_tool_call(self, tool_call): - return { - "id": tool_call.id, - "type": tool_call.type, - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments - } - } - if __name__ == "__main__": import asyncio - from agentpress.db import Database from tools.files_tool import FilesTool async def main(): - db = Database() - manager = ThreadManager(db) - + manager = ThreadManager() + + manager.add_tool(FilesTool, ['create_file', 'read_file']) + thread_id = await manager.create_thread() - await manager.add_message(thread_id, {"role": "user", "content": "Let's have a conversation about artificial intelligence and create a file summarizing our discussion."}) + + await manager.add_message(thread_id, {"role": "user", "content": "Please create a file with a random name with the content 'Hello, world!'"}) + + system_message = {"role": "system", "content": "You are a helpful assistant that can create, read, update, and delete files."} + model_name = "gpt-4o" - system_message = {"role": "system", "content": "You are an AI expert engaging in a conversation about artificial intelligence. You can also create and manage files."} - - files_tool = FilesTool() - tool_schemas = files_tool.get_schemas() + # Test with tools + response_with_tools = await manager.run_thread( + thread_id=thread_id, + system_message=system_message, + model_name=model_name, + temperature=0.7, + max_tokens=150, + tool_choice="auto", + additional_message=None, + execute_tools_async=True, + execute_model_tool_calls=True, + use_tools=True + ) - def initializer(): - print("Initializing thread run...") - manager.run_config['temperature'] = 0.8 + print("Response with tools:", response_with_tools) - def pre_iteration(): - print(f"Preparing iteration {manager.current_iteration}...") - manager.run_config['max_tokens'] = 200 if manager.current_iteration > 3 else 150 + # Test without tools + response_without_tools = await manager.run_thread( + thread_id=thread_id, + system_message=system_message, + model_name=model_name, + temperature=0.7, + max_tokens=150, + additional_message={"role": "user", "content": "What's the capital of France?"}, + use_tools=False + ) - def after_iteration(): - print(f"Completed iteration {manager.current_iteration}. 
Status: {manager.run_config['status']}") - manager.run_config['continue_instructions'] = "Let's focus more on AI ethics in the next iteration and update our summary file." - - def finalizer(): - print(f"Thread run finished with status: {manager.run_config['status']}") - print(f"Final configuration: {manager.run_config}") - - settings = { - "thread_id": thread_id, - "system_message": system_message, - "model_name": "gpt-4o", - "temperature": 0.7, - "max_tokens": 150, - "autonomous_iterations_amount": 3, - "continue_instructions": "Continue the conversation about AI, introducing new aspects or asking thought-provoking questions. Don't forget to update our summary file.", - "initializer": initializer, - "pre_iteration": pre_iteration, - "after_iteration": after_iteration, - "finalizer": finalizer, - "tools": list(tool_schemas.keys()), - "tool_choice": "auto" - } - - response = await manager.run_thread(settings) - - print(f"Thread run response: {response}") + print("Response without tools:", response_without_tools) + # List messages in the thread messages = await manager.list_messages(thread_id) - print("\nFinal conversation:") + print("\nMessages in the thread:") for msg in messages: print(f"{msg['role'].capitalize()}: {msg['content']}") - asyncio.run(main()) \ No newline at end of file + # Run the async main function + asyncio.run(main()) diff --git a/agentpress/tool.py b/agentpress/tool.py index d2fbee70..86fdada1 100644 --- a/agentpress/tool.py +++ b/agentpress/tool.py @@ -1,23 +1,100 @@ -from datetime import datetime -from typing import List, Dict, Any +""" +This module provides the foundation for creating and managing tools in the AgentPress system. + +The tool system allows for easy creation of function-like tools that can be used by AI models. +It provides a way to define OpenAPI schemas for these tools, which can then be used to generate +appropriate function calls in the AI model's context. + +Key components: +- ToolResult: A dataclass representing the result of a tool execution. +- Tool: An abstract base class that all tools should inherit from. +- tool_schema: A decorator for easily defining OpenAPI schemas for tool methods. + +Usage: +1. Create a new tool by subclassing Tool. +2. Define methods in your tool class and decorate them with @tool_schema. +3. The Tool class will automatically register these schemas. +4. Use the tool in your ThreadManager by adding it with add_tool method. + +Example: + class MyTool(Tool): + @tool_schema({ + "name": "add", + "description": "Add two numbers", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"} + }, + "required": ["a", "b"] + } + }) + async def add(self, a: float, b: float) -> ToolResult: + return self.success_response(f"The sum is {a + b}") + + # In your thread manager: + manager.add_tool(MyTool) +""" + +from typing import Dict, Any from dataclasses import dataclass -from abc import ABC, abstractmethod +from abc import ABC import json +import inspect @dataclass class ToolResult: + """ + Represents the result of a tool execution. + + Attributes: + success (bool): Whether the tool execution was successful. + output (str): The output of the tool execution. + """ success: bool output: str class Tool(ABC): - def __init__(self): - pass + """ + Abstract base class for all tools. + + This class provides the basic structure and functionality for tools. 
+ Subclasses should implement specific tool methods decorated with @tool_schema. + + Methods: + get_schemas(): Returns a dictionary of all registered tool schemas. + success_response(data): Creates a successful ToolResult. + fail_response(msg): Creates a failed ToolResult. + """ + def __init__(self): + self._schemas = {} + self._register_schemas() + + def _register_schemas(self): + """ + Automatically registers schemas for all methods decorated with @tool_schema. + """ + for name, method in inspect.getmembers(self, predicate=inspect.ismethod): + if hasattr(method, 'schema'): + self._schemas[name] = method.schema - @abstractmethod def get_schemas(self) -> Dict[str, Dict[str, Any]]: - pass + """ + Returns a dictionary of all registered tool schemas, formatted for use with AI models. + """ + return self._schemas def success_response(self, data: Dict[str, Any] | str) -> ToolResult: + """ + Creates a successful ToolResult with the given data. + + Args: + data: The data to include in the success response. + + Returns: + A ToolResult indicating success. + """ if isinstance(data, str): text = data else: @@ -25,10 +102,47 @@ class Tool(ABC): return ToolResult(success=True, output=text) def fail_response(self, msg: str) -> ToolResult: + """ + Creates a failed ToolResult with the given error message. + + Args: + msg: The error message to include in the failure response. + + Returns: + A ToolResult indicating failure. + """ return ToolResult(success=False, output=msg) - def format_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: - return { +def tool_schema(schema: Dict[str, Any]): + """ + A decorator for easily defining OpenAPI schemas for tool methods. + + This decorator allows you to define the schema for a tool method inline with the method definition. + It attaches the provided schema directly to the method. + + Args: + schema (Dict[str, Any]): An OpenAPI schema describing the tool. 
+ + Example: + @tool_schema({ + "name": "add", + "description": "Add two numbers", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"} + }, + "required": ["a", "b"] + } + }) + async def add(self, a: float, b: float) -> ToolResult: + return self.success_response(f"The sum is {a + b}") + """ + def decorator(func): + func.schema = { "type": "function", "function": schema - } \ No newline at end of file + } + return func + return decorator diff --git a/agentpress/tool_registry.py b/agentpress/tool_registry.py index 4b1c1f91..8345c602 100644 --- a/agentpress/tool_registry.py +++ b/agentpress/tool_registry.py @@ -1,4 +1,4 @@ -from typing import Dict, Type, Any, Optional +from typing import Dict, Type, Any, List, Optional from agentpress.tool import Tool from agentpress.config import settings import importlib.util @@ -8,39 +8,34 @@ import inspect class ToolRegistry: def __init__(self): self.tools: Dict[str, Dict[str, Any]] = {} - self.register_all_tools() - def register_tool(self, tool_cls: Type[Tool]): + def register_tool(self, tool_cls: Type[Tool], function_names: Optional[List[str]] = None): tool_instance = tool_cls() schemas = tool_instance.get_schemas() - for func_name, schema in schemas.items(): - self.tools[func_name] = { - "instance": tool_instance, - "schema": schema - } + + if function_names is None: + # Register all functions + for func_name, schema in schemas.items(): + self.tools[func_name] = { + "instance": tool_instance, + "schema": schema + } + else: + # Register only specified functions + for func_name in function_names: + if func_name in schemas: + self.tools[func_name] = { + "instance": tool_instance, + "schema": schemas[func_name] + } + else: + raise ValueError(f"Function '{func_name}' not found in {tool_cls.__name__}") - def register_all_tools(self): - tools_dir = settings.tools_dir - for file in os.listdir(tools_dir): - if file.endswith('.py') and file not in ['__init__.py', 'tool.py', 'tool_registry.py']: - module_path = os.path.join(tools_dir, file) - module_name = os.path.splitext(file)[0] - try: - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, Tool) and obj != Tool: - print(f"Registering tool: {name}") # Debug print - self.register_tool(obj) - except Exception as e: - print(f"Error importing {module_path}: {e}") # Debug print - - print(f"Registered tools: {list(self.tools.keys())}") # Debug print - - def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]: - return self.tools.get(tool_name) + def get_tool(self, tool_name: str) -> Dict[str, Any]: + return self.tools.get(tool_name, {}) def get_all_tools(self) -> Dict[str, Dict[str, Any]]: - return self.tools \ No newline at end of file + return self.tools + + def get_all_tool_schemas(self) -> List[Dict[str, Any]]: + return [tool_info['schema'] for tool_info in self.tools.values()] diff --git a/agentpress/ui/__init__.py b/agentpress/ui/__init__.py deleted file mode 100644 index 54821b92..00000000 --- a/agentpress/ui/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .main import main -from .tool_display import display_tools - -__all__ = ['main', 'display_tools'] diff --git a/agentpress/ui/main.py b/agentpress/ui/main.py deleted file mode 100644 index f81e64cb..00000000 --- a/agentpress/ui/main.py +++ 
/dev/null
@@ -1,44 +0,0 @@
-import streamlit as st
-from agentpress.ui.thread_management import display_thread_management
-from agentpress.ui.message_display import display_messages_and_runner
-from agentpress.ui.thread_runner import fetch_thread_runs, display_runs
-from agentpress.ui.tool_display import display_tools
-from agentpress.ui.utils import initialize_session_state, fetch_data, API_BASE_URL
-
-def main():
-    initialize_session_state()
-    fetch_data()
-
-    st.set_page_config(page_title="AI Assistant Management System", layout="wide")
-
-    st.sidebar.title("Navigation")
-    mode = st.sidebar.radio("Select Mode", ["Thread Management", "Tools"])
-
-    st.title("AI Assistant Management System")
-
-    if mode == "Tools":
-        display_tools()
-    else: # Thread Management
-        display_thread_management_content()
-
-def display_thread_management_content():
-    col1, col2 = st.columns([1, 3])
-
-    with col1:
-        display_thread_management()
-        if st.session_state.selected_thread:
-            display_thread_runner(st.session_state.selected_thread)
-
-    with col2:
-        if st.session_state.selected_thread:
-            display_messages_and_runner(st.session_state.selected_thread)
-
-def display_thread_runner(thread_id):
-
-    limit = st.number_input("Number of runs to retrieve", min_value=1, max_value=100, value=20)
-    if st.button("Fetch Runs"):
-        runs = fetch_thread_runs(thread_id, limit)
-        # display_runs(runs)
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/agentpress/ui/message_display.py b/agentpress/ui/message_display.py
deleted file mode 100644
index 20efc5ff..00000000
--- a/agentpress/ui/message_display.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import streamlit as st
-import requests
-from agentpress.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE
-from agentpress.ui.thread_runner import prepare_run_thread_data, run_thread, display_response_content
-
-def display_messages_and_runner(thread_id):
-    st.subheader(f"Messages for Thread: {thread_id}")
-
-    messages_container = st.empty()
-
-    def update_messages():
-        messages = fetch_messages(thread_id)
-        with messages_container.container():
-            display_message_list(messages)
-
-    update_messages()
-
-    display_add_message_form(thread_id, update_messages)
-    display_run_thread_form(thread_id, update_messages)
-
-def fetch_messages(thread_id):
-    messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/")
-    if messages_response.status_code == 200:
-        return messages_response.json()
-    else:
-        st.error("Failed to fetch messages.")
-        return []
-
-def display_message_list(messages):
-    for msg in messages:
-        with st.chat_message(msg['role']):
-            st.write(msg['content'])
-
-def display_add_message_form(thread_id, update_callback):
-    st.write("### Add a New Message")
-    with st.form(key="add_message_form"):
-        role = st.selectbox("Role", ["user", "assistant"], key="add_role")
-        content = st.text_area("Content", key="add_content")
-        submitted = st.form_submit_button("Add Message")
-        if submitted:
-            add_message(thread_id, role, content)
-            update_callback()
-
-def display_run_thread_form(thread_id, update_callback):
-    st.write("### Run Thread")
-    with st.form(key="run_thread_form"):
-        model_name = st.selectbox("Model", AI_MODELS, key="model_name")
-        temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature")
-        max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens")
-        system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100)
-        additional_system_message = st.text_area("Additional System Message", key="additional_system_message", height=100)
-
-        tool_options = list(st.session_state.tools.keys())
-        selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools")
-
-        autonomous_iterations_amount = st.number_input("Autonomous Iterations", min_value=1, max_value=10, value=1, key="autonomous_iterations_amount")
-        continue_instructions = st.text_area("Continue Instructions", key="continue_instructions", height=100, help="Instructions for continuing the conversation in subsequent iterations")
-
-        initializer = st.text_input("Initializer Function", key="initializer")
-        pre_iteration = st.text_input("Pre-iteration Function", key="pre_iteration")
-        after_iteration = st.text_input("After-iteration Function", key="after_iteration")
-        finalizer = st.text_input("Finalizer Function", key="finalizer")
-
-        submitted = st.form_submit_button("Run Thread")
-        if submitted:
-            run_thread_data = prepare_run_thread_data(
-                model_name, temperature, max_tokens, system_message,
-                additional_system_message, selected_tools,
-                autonomous_iterations_amount, continue_instructions,
-                initializer, pre_iteration, after_iteration, finalizer
-            )
-            response = run_thread(thread_id, run_thread_data)
-            if response:
-                st.subheader("Thread Run Response")
-                st.json(response)
-
-            # Update messages
-            update_callback()
-
-def add_message(thread_id, role, content):
-    message_data = {"role": role, "content": content}
-    add_msg_response = requests.post(
-        f"{API_BASE_URL}/threads/{thread_id}/messages/",
-        json=message_data
-    )
-    if add_msg_response.status_code == 200:
-        st.success("Message added successfully.")
-    else:
-        st.error("Failed to add message.")
-
-def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools, autonomous_iterations_amount, continue_instructions, initializer, pre_iteration, after_iteration, finalizer):
-    return {
-        "system_message": {"role": "system", "content": system_message},
-        "model_name": model_name,
-        "temperature": temperature,
-        "max_tokens": max_tokens,
-        "tools": selected_tools,
-        "additional_system_message": additional_system_message,
-        "tool_choice": "auto",
-        "autonomous_iterations_amount": autonomous_iterations_amount,
-        "continue_instructions": continue_instructions,
-        "initializer": initializer,
-        "pre_iteration": pre_iteration,
-        "after_iteration": after_iteration,
-        "finalizer": finalizer
-    }
diff --git a/agentpress/ui/real_time_updates.py b/agentpress/ui/real_time_updates.py
deleted file mode 100644
index fe88ddac..00000000
--- a/agentpress/ui/real_time_updates.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import streamlit as st
-import time
-import requests
-from agentpress.ui.utils import API_BASE_URL
-
-def get_run_status(thread_id, run_id, is_agent_run):
-    endpoint = f"agent_runs" if is_agent_run else f"runs"
-    response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/{endpoint}/{run_id}/status")
-    if response.status_code == 200:
-        return response.json()
-    return None
-
-def real_time_status_update(thread_id, run_id, is_agent_run):
-    status_placeholder = st.empty()
-    while True:
-        status = get_run_status(thread_id, run_id, is_agent_run)
-        if status:
-            status_placeholder.write(f"Current status: {status['status']}")
-            if status['status'] in ['completed', 'failed', 'cancelled']:
-                break
-        time.sleep(1)
-    return status
\ No newline at end of file
diff --git a/agentpress/ui/thread_management.py b/agentpress/ui/thread_management.py
deleted file mode 100644
index 4f1654b7..00000000
--- a/agentpress/ui/thread_management.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import streamlit as st
-import requests
-from agentpress.ui.utils import API_BASE_URL
-from datetime import datetime
-from agentpress.ui.thread_runner import stop_thread_run, stop_agent_run, get_thread_run_status, get_agent_run_status
-
-def display_thread_management():
-    st.subheader("Thread Management")
-
-    if st.button("➕ Create New Thread", key="create_thread_button"):
-        create_new_thread()
-
-    display_thread_selector()
-
-    if st.session_state.selected_thread:
-        display_run_history(st.session_state.selected_thread)
-
-def create_new_thread():
-    response = requests.post(f"{API_BASE_URL}/threads/")
-    if response.status_code == 200:
-        thread_id = response.json()['thread_id']
-        st.session_state.selected_thread = thread_id
-        st.success(f"New thread created with ID: {thread_id}")
-        st.rerun()
-    else:
-        st.error("Failed to create a new thread.")
-
-def display_thread_selector():
-    threads_response = requests.get(f"{API_BASE_URL}/threads/")
-    if threads_response.status_code == 200:
-        threads = threads_response.json()
-
-        # Sort threads by created_at timestamp (newest first)
-        sorted_threads = sorted(threads, key=lambda x: x['created_at'], reverse=True)
-
-        thread_options = [f"{thread['thread_id']} - Created: {format_timestamp(thread['created_at'])}" for thread in sorted_threads]
-
-        if st.session_state.selected_thread is None and sorted_threads:
-            st.session_state.selected_thread = sorted_threads[0]['thread_id']
-
-        selected_thread = st.selectbox(
-            "🔍 Select Thread",
-            thread_options,
-            key="thread_select",
-            index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0)
-        )
-
-        if selected_thread:
-            st.session_state.selected_thread = selected_thread.split(' - ')[0]
-    else:
-        st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}")
-
-def display_run_history(thread_id):
-    st.subheader("Run History")
-
-    # Fetch thread runs
-    thread_runs = fetch_thread_runs(thread_id)
-
-    # Display thread runs
-    st.write("### Thread Runs")
-    for run in thread_runs:
-        with st.expander(f"Run {run['id']} - Status: {run['status']}"):
-            st.write(f"Created At: {format_timestamp(run['created_at'])}")
-            st.write(f"Status: {run['status']}")
-            st.write(f"Iterations: {len(run.get('iterations', []))} / {run.get('autonomous_iterations_amount', 1)}")
-
-            if run['status'] == "in_progress":
-                if st.button(f"Stop Run {run['id']}", key=f"stop_thread_run_{run['id']}"):
-                    stop_thread_run(thread_id, run['id'])
-                    st.rerun()
-
-            if st.button(f"Refresh Status for Run {run['id']}", key=f"refresh_thread_run_{run['id']}"):
-                updated_run = get_thread_run_status(thread_id, run['id'])
-                if updated_run:
-                    run.update(updated_run)
-                    st.rerun()
-
-def fetch_thread_runs(thread_id):
-    response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs")
-    if response.status_code == 200:
-        return response.json()
-    else:
-        st.error("Failed to fetch thread runs.")
-        return []
-
-def format_timestamp(timestamp):
-    return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
diff --git a/agentpress/ui/thread_runner.py b/agentpress/ui/thread_runner.py
deleted file mode 100644
index aa094535..00000000
--- a/agentpress/ui/thread_runner.py
+++ /dev/null
@@ -1,195 +0,0 @@
-import streamlit as st
-import requests
-from agentpress.ui.utils import API_BASE_URL
-from datetime import datetime
-
-def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools):
-    return {
-        "system_message": {"role": "system", "content": system_message},
-        "model_name": model_name,
-        "temperature": temperature,
-        "max_tokens": max_tokens,
-        "tools": selected_tools,
-        "additional_system_message": additional_system_message,
-        "tool_choice": "auto" # Add this line to ensure tool_choice is always set
-    }
-
-def prepare_run_thread_agent_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools, autonomous_iterations_amount, continue_instructions):
-    return {
-        "system_message": {"role": "system", "content": system_message},
-        "model_name": model_name,
-        "temperature": temperature,
-        "max_tokens": max_tokens,
-        "tools": selected_tools,
-        "additional_system_message": additional_system_message,
-        "autonomous_iterations_amount": autonomous_iterations_amount,
-        "continue_instructions": continue_instructions
-    }
-
-def run_thread(thread_id, run_thread_data):
-    with st.spinner("Running thread..."):
-        try:
-            run_thread_response = requests.post(
-                f"{API_BASE_URL}/threads/{thread_id}/run/",
-                json=run_thread_data
-            )
-            run_thread_response.raise_for_status()
-            response_data = run_thread_response.json()
-            st.success(f"Thread run completed successfully! Status: {response_data.get('status', 'Unknown')}")
-
-            if 'id' in response_data:
-                st.session_state.latest_run_id = response_data['id']
-
-            st.subheader("Response Content")
-            display_response_content(response_data)
-
-            # Display the full response data
-            st.subheader("Full Response Data")
-            st.json(response_data)
-
-            return response_data # Return the response data
-        except requests.exceptions.RequestException as e:
-            st.error(f"Failed to run thread. Error: {str(e)}")
-            if hasattr(e, 'response') and e.response is not None:
-                st.text("Response content:")
-                st.text(e.response.text)
-        except Exception as e:
-            st.error(f"An unexpected error occurred: {str(e)}")
-
-    return None # Return None if there was an error
-
-def run_thread_agent(thread_id, run_thread_agent_data):
-    with st.spinner("Running thread agent..."):
-        try:
-            run_thread_response = requests.post(
-                f"{API_BASE_URL}/threads/{thread_id}/run_agent/",
-                json=run_thread_agent_data
-            )
-            run_thread_response.raise_for_status()
-            response_data = run_thread_response.json()
-            st.success(f"Thread agent run completed successfully! Status: {response_data['status']}")
-
-            st.subheader("Agent Response")
-            display_agent_response_content(response_data)
-
-            # Display the full response data
-            st.subheader("Full Agent Response Data")
-            st.json(response_data)
-
-            return response_data
-        except requests.exceptions.RequestException as e:
-            st.error(f"Failed to run thread agent. Error: {str(e)}")
-            if hasattr(e, 'response') and e.response is not None:
-                st.text("Response content:")
-                st.text(e.response.text)
-        except Exception as e:
-            st.error(f"An unexpected error occurred: {str(e)}")
-
-def display_response_content(response_data):
-    if isinstance(response_data, dict) and 'choices' in response_data:
-        message = response_data['choices'][0]['message']
-        st.write(f"**Role:** {message['role']}")
-        st.write(f"**Content:** {message['content']}")
-
-        if 'tool_calls' in message and message['tool_calls']:
-            st.write("**Tool Calls:**")
-            for tool_call in message['tool_calls']:
-                st.write(f"- Function: `{tool_call['function']['name']}`")
-                st.code(tool_call['function']['arguments'], language="json")
-    else:
-        st.json(response_data)
-
-def display_agent_response_content(response_data):
-    st.write(f"**Status:** {response_data['status']}")
-    st.write(f"**Total Iterations:** {response_data['total_iterations']}")
-    st.write(f"**Completed Iterations:** {response_data.get('iterations_count', 'N/A')}")
-
-    for i, iteration in enumerate(response_data['iterations']):
-        with st.expander(f"Iteration {i+1}"):
-            display_response_content(iteration)
-
-    st.write("**Final Configuration:**")
-    st.json(response_data['final_config'])
-
-def fetch_thread_runs(thread_id, limit):
-    response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs?limit={limit}")
-    if response.status_code == 200:
-        return response.json()
-    else:
-        st.error("Failed to retrieve runs.")
-        return []
-
-def format_timestamp(timestamp):
-    if timestamp:
-        return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
-    return 'N/A'
-
-def display_runs(runs):
-    for run in runs:
-        with st.expander(f"Run {run['id']} - Status: {run['status']}", expanded=False):
-            col1, col2 = st.columns(2)
-            with col1:
-                st.write(f"**Created At:** {format_timestamp(run['created_at'])}")
-                st.write(f"**Started At:** {format_timestamp(run['started_at'])}")
-                st.write(f"**Completed At:** {format_timestamp(run['completed_at'])}")
-                st.write(f"**Cancelled At:** {format_timestamp(run['cancelled_at'])}")
-                st.write(f"**Failed At:** {format_timestamp(run['failed_at'])}")
-            with col2:
-                st.write(f"**Model:** {run['model']}")
-                st.write(f"**Temperature:** {run['temperature']}")
-                st.write(f"**Top P:** {run['top_p']}")
-                st.write(f"**Max Tokens:** {run['max_tokens']}")
-                st.write(f"**Tool Choice:** {run['tool_choice']}")
-                st.write(f"**Execute Tools Async:** {run['execute_tools_async']}")
-                st.write(f"**Autonomous Iterations:** {run['autonomous_iterations_amount']}")
-
-            st.write("**System Message:**")
-            st.json(run['system_message'])
-
-            if run['tools']:
-                st.write("**Tools:**")
-                st.json(run['tools'])
-
-            if run['usage']:
-                st.write("**Usage:**")
-                st.json(run['usage'])
-
-            if run['response_format']:
-                st.write("**Response Format:**")
-                st.json(run['response_format'])
-
-            if run['last_error']:
-                st.error("**Last Error:**")
-                st.code(run['last_error'])
-
-            if run['continue_instructions']:
-                st.write("**Continue Instructions:**")
-                st.text(run['continue_instructions'])
-
-            if run['status'] == "in_progress":
-                if st.button(f"Stop Run {run['id']}", key=f"stop_button_{run['id']}"):
-                    stop_thread_run(run['thread_id'], run['id'])
-                    st.rerun()
-
-            if st.button(f"Refresh Status for Run {run['id']}", key=f"refresh_button_{run['id']}"):
-                updated_run = get_thread_run_status(run['thread_id'], run['id'])
-                if updated_run:
-                    run.update(updated_run)
-                    st.rerun()
-
-def stop_thread_run(thread_id, run_id):
-    response = requests.post(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/stop")
-    if response.status_code == 200:
-        st.success("Thread run stopped successfully.")
-        return response.json()
-    else:
-        st.error(f"Failed to stop thread run. Status code: {response.status_code}")
-        return None
-
-def get_thread_run_status(thread_id, run_id):
-    response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/status")
-    if response.status_code == 200:
-        return response.json()
-    else:
-        st.error(f"Failed to get thread run status. Status code: {response.status_code}")
-        return None
diff --git a/agentpress/ui/tool_display.py b/agentpress/ui/tool_display.py
deleted file mode 100644
index eaa5eed0..00000000
--- a/agentpress/ui/tool_display.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import streamlit as st
-import requests
-from agentpress.ui.utils import API_BASE_URL
-
-def display_tools():
-    st.header("Available Tools")
-
-    tools = fetch_tools()
-
-    if not tools:
-        st.warning("No tools available. Please check the API connection.")
-        return
-
-    view_mode = st.radio("View Mode", ["Simple", "Detailed"])
-
-    if view_mode == "Simple":
-        display_simple_view(tools)
-    else:
-        display_detailed_view(tools)
-
-def display_simple_view(tools):
-    for tool_name, tool_info in tools.items():
-        with st.expander(f"🛠️ {tool_name}"):
-            st.write(f"**Description:** {tool_info['description']}")
-
-def display_detailed_view(tools):
-    for tool_name, tool_info in tools.items():
-        with st.expander(f"🛠️ {tool_name}"):
-            st.write(f"**Description:** {tool_info['description']}")
-            if tool_info['schema']:
-                st.write("**Schema:**")
-                st.json(tool_info['schema'])
-
-def fetch_tools():
-    response = requests.get(f"{API_BASE_URL}/tools/")
-    if response.status_code == 200:
-        tools = response.json()
-        st.session_state.tools = tools
-        return tools
-    else:
-        st.error(f"Failed to fetch tools. Status code: {response.status_code}")
-        st.error(f"Error message: {response.text}")
-        return {}
\ No newline at end of file
diff --git a/agentpress/ui/utils.py b/agentpress/ui/utils.py
deleted file mode 100644
index 63570434..00000000
--- a/agentpress/ui/utils.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import streamlit as st
-import requests
-from agentpress.constants import AI_MODELS, STANDARD_SYSTEM_MESSAGE
-
-API_BASE_URL = "http://localhost:8000"
-
-def fetch_tools():
-    response = requests.get(f"{API_BASE_URL}/tools/")
-    if response.status_code == 200:
-        st.session_state.tools = response.json()
-    else:
-        st.error("Failed to fetch tools.")
-
-def initialize_session_state():
-    if 'selected_thread' not in st.session_state:
-        st.session_state.selected_thread = None
-    if 'tools' not in st.session_state:
-        st.session_state.tools = []
-    if 'fetch_tools' not in st.session_state:
-        st.session_state.fetch_tools = fetch_tools
-
-def fetch_data():
-    fetch_tools()
\ No newline at end of file