diff --git a/CHANGELOG.md b/CHANGELOG.md index f229936e..029717a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +0.1.9 +- Enhanced storage capabilities for production readiness: + - Added SQLite as primary storage backend for ThreadManager and StateManager + - Implemented persistent storage with unique store IDs + - Added CRUD operations for state management + - Enabled multiple concurrent stores with referential integrity + - Improved state persistence and retrieval mechanisms + 0.1.8 - Added base processor classes for extensible tool handling: - ToolParserBase: Abstract base class for parsing LLM responses diff --git a/agentpress.db b/agentpress.db new file mode 100644 index 00000000..42d82dcc Binary files /dev/null and b/agentpress.db differ diff --git a/agentpress/agents/simple_web_dev/agent.py b/agentpress/agents/simple_web_dev/agent.py index 54dd3486..5300ed74 100644 --- a/agentpress/agents/simple_web_dev/agent.py +++ b/agentpress/agents/simple_web_dev/agent.py @@ -91,12 +91,13 @@ file contents here async def run_agent(thread_id: str, use_xml: bool = True, max_iterations: int = 5): """Run the development agent with specified configuration.""" thread_manager = ThreadManager() - state_manager = StateManager() - - thread_manager.add_tool(FilesTool) - thread_manager.add_tool(TerminalTool) - # Combine base message with XML format if needed + store_id = await StateManager.create_store() + state_manager = StateManager(store_id) + + thread_manager.add_tool(FilesTool, store_id=store_id) + thread_manager.add_tool(TerminalTool, store_id=store_id) + system_message = { "role": "system", "content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "") @@ -199,6 +200,7 @@ def main(): async def async_main(): thread_manager = ThreadManager() + thread_id = await thread_manager.create_thread() await thread_manager.add_message( thread_id, diff --git a/agentpress/agents/simple_web_dev/tools/files_tool.py b/agentpress/agents/simple_web_dev/tools/files_tool.py index c27159b6..d87ae9b9 100644 --- a/agentpress/agents/simple_web_dev/tools/files_tool.py +++ b/agentpress/agents/simple_web_dev/tools/files_tool.py @@ -1,8 +1,9 @@ import os import asyncio from pathlib import Path -from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema +from agentpress.tools.tool import Tool, ToolResult, openapi_schema, xml_schema from agentpress.state_manager import StateManager +from typing import Optional class FilesTool(Tool): """File management tool for creating, updating, and deleting files. @@ -53,11 +54,11 @@ class FilesTool(Tool): ".sql" } - def __init__(self): + def __init__(self, store_id: Optional[str] = None): super().__init__() self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace') os.makedirs(self.workspace, exist_ok=True) - self.state_manager = StateManager("state.json") + self.state_manager = StateManager(store_id) self.SNIPPET_LINES = 4 # Number of context lines to show around edits asyncio.create_task(self._init_workspace_state()) diff --git a/agentpress/agents/simple_web_dev/tools/terminal_tool.py b/agentpress/agents/simple_web_dev/tools/terminal_tool.py index 5bd7eb77..616093fc 100644 --- a/agentpress/agents/simple_web_dev/tools/terminal_tool.py +++ b/agentpress/agents/simple_web_dev/tools/terminal_tool.py @@ -3,15 +3,16 @@ import asyncio import subprocess from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema from agentpress.state_manager import StateManager +from typing import Optional class TerminalTool(Tool): """Terminal command execution tool for workspace operations.""" - def __init__(self): + def __init__(self, store_id: Optional[str] = None): super().__init__() self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace') os.makedirs(self.workspace, exist_ok=True) - self.state_manager = StateManager("state.json") + self.state_manager = StateManager(store_id) async def _update_command_history(self, command: str, output: str, success: bool): """Update command history in state""" diff --git a/agentpress/cli.py b/agentpress/cli.py index 5cb94cb5..24ee9911 100644 --- a/agentpress/cli.py +++ b/agentpress/cli.py @@ -26,14 +26,14 @@ MODULES = { "processors": { "required": True, "files": [ - "base_processors.py", - "llm_response_processor.py", - "standard_tool_parser.py", - "standard_tool_executor.py", - "standard_results_adder.py", - "xml_tool_parser.py", - "xml_tool_executor.py", - "xml_results_adder.py" + "processor/base_processors.py", + "processor/llm_response_processor.py", + "processor/standard/standard_tool_parser.py", + "processor/standard/standard_tool_executor.py", + "processor/standard/standard_results_adder.py", + "processor/xml/xml_tool_parser.py", + "processor/xml/xml_tool_executor.py", + "processor/xml/xml_results_adder.py" ], "description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns." }, diff --git a/agentpress/db_connection.py b/agentpress/db_connection.py new file mode 100644 index 00000000..c173c815 --- /dev/null +++ b/agentpress/db_connection.py @@ -0,0 +1,125 @@ +""" +Centralized database connection management for AgentPress. +""" + +import aiosqlite +import logging +from contextlib import asynccontextmanager +import os +import asyncio + +class DBConnection: + """Singleton database connection manager.""" + + _instance = None + _initialized = False + _db_path = os.path.join(os.getcwd(), "agentpress.db") + _init_lock = asyncio.Lock() + _initialization_task = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + # Start initialization when instance is first created + cls._initialization_task = asyncio.create_task(cls._instance._initialize()) + return cls._instance + + def __init__(self): + """No initialization needed in __init__ as it's handled in __new__""" + pass + + @classmethod + async def _initialize(cls): + """Internal initialization method.""" + if cls._initialized: + return + + async with cls._init_lock: + if cls._initialized: # Double-check after acquiring lock + return + + try: + async with aiosqlite.connect(cls._db_path) as db: + # Threads table + await db.execute(""" + CREATE TABLE IF NOT EXISTS threads ( + thread_id TEXT PRIMARY KEY, + messages TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # State stores table + await db.execute(""" + CREATE TABLE IF NOT EXISTS state_stores ( + store_id TEXT PRIMARY KEY, + store_data TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + await db.commit() + cls._initialized = True + logging.info("Database schema initialized") + except Exception as e: + logging.error(f"Database initialization error: {e}") + raise + + @classmethod + def set_db_path(cls, db_path: str): + """Set custom database path.""" + if cls._initialized: + raise RuntimeError("Cannot change database path after initialization") + cls._db_path = db_path + logging.info(f"Updated database path to: {db_path}") + + @asynccontextmanager + async def connection(self): + """Get a database connection.""" + # Wait for initialization to complete if it hasn't already + if self._initialization_task and not self._initialized: + await self._initialization_task + + async with aiosqlite.connect(self._db_path) as conn: + try: + yield conn + except Exception as e: + logging.error(f"Database error: {e}") + raise + + @asynccontextmanager + async def transaction(self): + """Execute operations in a transaction.""" + async with self.connection() as db: + try: + yield db + await db.commit() + except Exception as e: + await db.rollback() + logging.error(f"Transaction error: {e}") + raise + + async def execute(self, query: str, params: tuple = ()): + """Execute a single query.""" + async with self.connection() as db: + try: + result = await db.execute(query, params) + await db.commit() + return result + except Exception as e: + logging.error(f"Query execution error: {e}") + raise + + async def fetch_one(self, query: str, params: tuple = ()): + """Fetch a single row.""" + async with self.connection() as db: + async with db.execute(query, params) as cursor: + return await cursor.fetchone() + + async def fetch_all(self, query: str, params: tuple = ()): + """Fetch all rows.""" + async with self.connection() as db: + async with db.execute(query, params) as cursor: + return await cursor.fetchall() \ No newline at end of file diff --git a/agentpress/base_processors.py b/agentpress/processor/base_processors.py similarity index 98% rename from agentpress/base_processors.py rename to agentpress/processor/base_processors.py index 85f6eea4..ec724180 100644 --- a/agentpress/base_processors.py +++ b/agentpress/processor/base_processors.py @@ -172,7 +172,7 @@ class ResultsAdderBase(ABC): Attributes: add_message: Callback for adding new messages update_message: Callback for updating existing messages - list_messages: Callback for retrieving thread messages + get_messages: Callback for retrieving thread messages message_added: Flag tracking if initial message has been added """ @@ -184,7 +184,7 @@ class ResultsAdderBase(ABC): """ self.add_message = thread_manager.add_message self.update_message = thread_manager._update_message - self.list_messages = thread_manager.list_messages + self.get_messages = thread_manager.get_messages self.message_added = False @abstractmethod diff --git a/agentpress/llm_response_processor.py b/agentpress/processor/llm_response_processor.py similarity index 93% rename from agentpress/llm_response_processor.py rename to agentpress/processor/llm_response_processor.py index 5bf6bc7a..24572210 100644 --- a/agentpress/llm_response_processor.py +++ b/agentpress/processor/llm_response_processor.py @@ -11,10 +11,10 @@ This module provides comprehensive processing of LLM responses, including: import asyncio from typing import Callable, Dict, Any, AsyncGenerator, Optional import logging -from agentpress.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase -from agentpress.standard_tool_parser import StandardToolParser -from agentpress.standard_tool_executor import StandardToolExecutor -from agentpress.standard_results_adder import StandardResultsAdder +from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase +from agentpress.processor.standard.standard_tool_parser import StandardToolParser +from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor +from agentpress.processor.standard.standard_results_adder import StandardResultsAdder class LLMResponseProcessor: """Handles LLM response processing and tool execution management. @@ -40,9 +40,8 @@ class LLMResponseProcessor: available_functions: Dict = None, add_message_callback: Callable = None, update_message_callback: Callable = None, - list_messages_callback: Callable = None, + get_messages_callback: Callable = None, parallel_tool_execution: bool = True, - threads_dir: str = "threads", tool_parser: Optional[ToolParserBase] = None, tool_executor: Optional[ToolExecutorBase] = None, results_adder: Optional[ResultsAdderBase] = None, @@ -55,9 +54,8 @@ class LLMResponseProcessor: available_functions: Dictionary of available tool functions add_message_callback: Callback for adding messages update_message_callback: Callback for updating messages - list_messages_callback: Callback for listing messages + get_messages_callback: Callback for listing messages parallel_tool_execution: Whether to execute tools in parallel - threads_dir: Directory for thread storage tool_parser: Custom tool parser implementation tool_executor: Custom tool executor implementation results_adder: Custom results adder implementation @@ -67,16 +65,15 @@ class LLMResponseProcessor: self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution) self.tool_parser = tool_parser or StandardToolParser() self.available_functions = available_functions or {} - self.threads_dir = threads_dir # Create minimal thread manager if needed - if thread_manager is None and (add_message_callback and update_message_callback and list_messages_callback): + if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback): class MinimalThreadManager: def __init__(self, add_msg, update_msg, list_msg): self.add_message = add_msg self._update_message = update_msg - self.list_messages = list_msg - thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, list_messages_callback) + self.get_messages = list_msg + thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback) self.results_adder = results_adder or StandardResultsAdder(thread_manager) diff --git a/agentpress/standard_results_adder.py b/agentpress/processor/standard/standard_results_adder.py similarity index 96% rename from agentpress/standard_results_adder.py rename to agentpress/processor/standard/standard_results_adder.py index 8862bab8..49d08b64 100644 --- a/agentpress/standard_results_adder.py +++ b/agentpress/processor/standard/standard_results_adder.py @@ -1,5 +1,5 @@ from typing import Dict, Any, List, Optional -from agentpress.base_processors import ResultsAdderBase +from agentpress.processor.base_processors import ResultsAdderBase # --- Standard Results Adder Implementation --- @@ -81,6 +81,6 @@ class StandardResultsAdder(ResultsAdderBase): - Checks for duplicate tool results before adding - Adds result only if tool_call_id is unique """ - messages = await self.list_messages(thread_id) + messages = await self.get_messages(thread_id) if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages): await self.add_message(thread_id, result) diff --git a/agentpress/standard_tool_executor.py b/agentpress/processor/standard/standard_tool_executor.py similarity index 99% rename from agentpress/standard_tool_executor.py rename to agentpress/processor/standard/standard_tool_executor.py index e70cd966..575b9d72 100644 --- a/agentpress/standard_tool_executor.py +++ b/agentpress/processor/standard/standard_tool_executor.py @@ -9,7 +9,7 @@ import asyncio import json import logging from typing import Dict, Any, List, Set, Callable, Optional -from agentpress.base_processors import ToolExecutorBase +from agentpress.processor.base_processors import ToolExecutorBase from agentpress.tool import ToolResult # --- Standard Tool Executor Implementation --- diff --git a/agentpress/standard_tool_parser.py b/agentpress/processor/standard/standard_tool_parser.py similarity index 98% rename from agentpress/standard_tool_parser.py rename to agentpress/processor/standard/standard_tool_parser.py index 9ca551b9..ae024462 100644 --- a/agentpress/standard_tool_parser.py +++ b/agentpress/processor/standard/standard_tool_parser.py @@ -1,6 +1,6 @@ import json from typing import Dict, Any, Optional -from agentpress.base_processors import ToolParserBase +from agentpress.processor.base_processors import ToolParserBase # --- Standard Tool Parser Implementation --- diff --git a/agentpress/xml_results_adder.py b/agentpress/processor/xml/xml_results_adder.py similarity index 96% rename from agentpress/xml_results_adder.py rename to agentpress/processor/xml/xml_results_adder.py index 97822266..49593da8 100644 --- a/agentpress/xml_results_adder.py +++ b/agentpress/processor/xml/xml_results_adder.py @@ -1,6 +1,6 @@ import logging from typing import Dict, Any, List, Optional -from agentpress.base_processors import ResultsAdderBase +from agentpress.processor.base_processors import ResultsAdderBase class XMLResultsAdder(ResultsAdderBase): """XML-specific implementation for handling tool results and message processing. @@ -79,7 +79,7 @@ class XMLResultsAdder(ResultsAdderBase): """ try: # Get the original tool call to find the root tag - messages = await self.list_messages(thread_id) + messages = await self.get_messages(thread_id) assistant_msg = next((msg for msg in reversed(messages) if msg['role'] == 'assistant'), None) diff --git a/agentpress/xml_tool_executor.py b/agentpress/processor/xml/xml_tool_executor.py similarity index 98% rename from agentpress/xml_tool_executor.py rename to agentpress/processor/xml/xml_tool_executor.py index 1cf701bd..185e8e6c 100644 --- a/agentpress/xml_tool_executor.py +++ b/agentpress/processor/xml/xml_tool_executor.py @@ -9,7 +9,7 @@ from typing import List, Dict, Any, Set, Callable, Optional import asyncio import json import logging -from agentpress.base_processors import ToolExecutorBase +from agentpress.processor.base_processors import ToolExecutorBase from agentpress.tool import ToolResult from agentpress.tool_registry import ToolRegistry diff --git a/agentpress/xml_tool_parser.py b/agentpress/processor/xml/xml_tool_parser.py similarity index 99% rename from agentpress/xml_tool_parser.py rename to agentpress/processor/xml/xml_tool_parser.py index 8942a4ae..890c00aa 100644 --- a/agentpress/xml_tool_parser.py +++ b/agentpress/processor/xml/xml_tool_parser.py @@ -7,7 +7,7 @@ complete and streaming responses with robust XML parsing and validation capabili import logging from typing import Dict, Any, Optional, List, Tuple -from agentpress.base_processors import ToolParserBase +from agentpress.processor.base_processors import ToolParserBase import json import re from agentpress.tool_registry import ToolRegistry diff --git a/agentpress/state_manager.py b/agentpress/state_manager.py index c822e862..a3b88ad7 100644 --- a/agentpress/state_manager.py +++ b/agentpress/state_manager.py @@ -1,76 +1,100 @@ import json -import os import logging -from typing import Any +from typing import Any, Optional, List, Dict, Union, AsyncGenerator from asyncio import Lock from contextlib import asynccontextmanager +import uuid +from agentpress.db_connection import DBConnection +import asyncio class StateManager: """ Manages persistent state storage for AgentPress components. - The StateManager provides thread-safe access to a JSON-based state store, - allowing components to save and retrieve data across sessions. It handles - concurrent access using asyncio locks and provides atomic operations for - state modifications. + The StateManager provides thread-safe access to a SQLite-based state store, + allowing components to save and retrieve data across sessions. Each store + has a unique ID and contains multiple key-value pairs in a single JSON object. Attributes: lock (Lock): Asyncio lock for thread-safe state access - store_file (str): Path to the JSON file storing the state + db (DBConnection): Database connection manager + store_id (str): Unique identifier for this state store """ - def __init__(self, store_file: str = "state.json"): + def __init__(self, store_id: Optional[str] = None): """ - Initialize StateManager with custom store file name. + Initialize StateManager with optional store ID. Args: - store_file (str): Path to the JSON file to store state. - Defaults to "state.json" in the current directory. + store_id (str, optional): Unique identifier for the store. If None, creates new. """ self.lock = Lock() - self.store_file = store_file - logging.info(f"StateManager initialized with store file: {store_file}") + self.db = DBConnection() + self.store_id = store_id or str(uuid.uuid4()) + logging.info(f"StateManager initialized with store_id: {self.store_id}") + asyncio.create_task(self._ensure_store_exists()) + + @classmethod + async def create_store(cls) -> str: + """Create a new state store and return its ID.""" + store_id = str(uuid.uuid4()) + manager = cls(store_id) + await manager._ensure_store_exists() + return store_id + + async def _ensure_store_exists(self): + """Ensure store exists in database.""" + async with self.db.transaction() as conn: + await conn.execute(""" + INSERT OR IGNORE INTO state_stores (store_id, store_data) + VALUES (?, ?) + """, (self.store_id, json.dumps({}))) @asynccontextmanager async def store_scope(self): """ Context manager for atomic state operations. - Provides thread-safe access to the state store, handling file I/O - and ensuring proper cleanup. Automatically loads the current state - and saves changes when the context exits. + Provides thread-safe access to the state store, handling database + operations and ensuring proper cleanup. Yields: dict: The current state store contents Raises: - Exception: If there are errors reading from or writing to the store file + Exception: If there are errors with database operations """ - try: - # Read current state - if os.path.exists(self.store_file): - with open(self.store_file, 'r') as f: - store = json.load(f) - else: - store = {} - - yield store - - # Write updated state - with open(self.store_file, 'w') as f: - json.dump(store, f, indent=2) - logging.debug("Store saved successfully") - except Exception as e: - logging.error("Error in store operation", exc_info=True) - raise + async with self.lock: + try: + async with self.db.transaction() as conn: + async with conn.execute( + "SELECT store_data FROM state_stores WHERE store_id = ?", + (self.store_id,) + ) as cursor: + row = await cursor.fetchone() + store = json.loads(row[0]) if row else {} + + yield store + + await conn.execute( + """ + UPDATE state_stores + SET store_data = ?, updated_at = CURRENT_TIMESTAMP + WHERE store_id = ? + """, + (json.dumps(store), self.store_id) + ) + except Exception as e: + logging.error("Error in store operation", exc_info=True) + raise - async def set(self, key: str, data: Any): + async def set(self, key: str, data: Any) -> Any: """ - Store any JSON-serializable data with a simple key. + Store any JSON-serializable data with a key. Args: key (str): Simple string key like "config" or "settings" - data (Any): Any JSON-serializable data (dict, list, str, int, bool, etc) + data (Any): Any JSON-serializable data Returns: Any: The stored data @@ -78,17 +102,12 @@ class StateManager: Raises: Exception: If there are errors during storage operation """ - async with self.lock: - async with self.store_scope() as store: - try: - store[key] = data # Will be JSON serialized when written to file - logging.info(f'Updated store key: {key}') - return data - except Exception as e: - logging.error(f'Error in set: {str(e)}') - raise + async with self.store_scope() as store: + store[key] = data + logging.info(f'Updated store key: {key}') + return data - async def get(self, key: str) -> Any: + async def get(self, key: str) -> Optional[Any]: """ Get data for a key. @@ -97,9 +116,6 @@ class StateManager: Returns: Any: The stored data for the key, or None if key not found - - Note: - This operation is read-only and doesn't require locking """ async with self.store_scope() as store: if key in store: @@ -115,17 +131,31 @@ class StateManager: Args: key (str): Simple string key like "config" or "settings" - - Note: - No error is raised if the key doesn't exist """ - async with self.lock: - async with self.store_scope() as store: - if key in store: - del store[key] - logging.info(f"Deleted key: {key}") - else: - logging.info(f"Key not found for deletion: {key}") + async with self.store_scope() as store: + if key in store: + del store[key] + logging.info(f"Deleted key: {key}") + + async def update(self, key: str, data: Dict[str, Any]) -> Optional[Any]: + """Update existing data for a key by merging dictionaries.""" + async with self.store_scope() as store: + if key in store and isinstance(store[key], dict): + store[key].update(data) + logging.info(f'Updated store key: {key}') + return store[key] + return None + + async def append(self, key: str, item: Any) -> Optional[List[Any]]: + """Append an item to a list stored at key.""" + async with self.store_scope() as store: + if key not in store: + store[key] = [] + if isinstance(store[key], list): + store[key].append(item) + logging.info(f'Appended to key: {key}') + return store[key] + return None async def export_store(self) -> dict: """ @@ -133,9 +163,6 @@ class StateManager: Returns: dict: Complete contents of the state store - - Note: - This operation is read-only and returns a copy of the store """ async with self.store_scope() as store: logging.info(f"Store content: {store}") @@ -148,7 +175,29 @@ class StateManager: Removes all data from the store, resetting it to an empty state. This operation is atomic and thread-safe. """ - async with self.lock: - async with self.store_scope() as store: - store.clear() - logging.info("Cleared store") + async with self.store_scope() as store: + store.clear() + logging.info("Cleared store") + + @classmethod + async def list_stores(cls) -> List[Dict[str, Any]]: + """ + List all available state stores. + + Returns: + List of store information including IDs and timestamps + """ + db = DBConnection() + async with db.transaction() as conn: + async with conn.execute( + "SELECT store_id, created_at, updated_at FROM state_stores ORDER BY updated_at DESC" + ) as cursor: + stores = [ + { + "store_id": row[0], + "created_at": row[1], + "updated_at": row[2] + } + for row in await cursor.fetchall() + ] + return stores diff --git a/agentpress/thread_manager.py b/agentpress/thread_manager.py index 7a5c7802..61eac089 100644 --- a/agentpress/thread_manager.py +++ b/agentpress/thread_manager.py @@ -11,21 +11,22 @@ This module provides comprehensive conversation management, including: import json import logging -import os +import asyncio import uuid from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator from agentpress.llm import make_llm_api_call from agentpress.tool import Tool, ToolResult from agentpress.tool_registry import ToolRegistry -from agentpress.llm_response_processor import LLMResponseProcessor -from agentpress.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase +from agentpress.processor.llm_response_processor import LLMResponseProcessor +from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase +from agentpress.db_connection import DBConnection -from agentpress.xml_tool_parser import XMLToolParser -from agentpress.xml_tool_executor import XMLToolExecutor -from agentpress.xml_results_adder import XMLResultsAdder -from agentpress.standard_tool_parser import StandardToolParser -from agentpress.standard_tool_executor import StandardToolExecutor -from agentpress.standard_results_adder import StandardResultsAdder +from agentpress.processor.xml.xml_tool_parser import XMLToolParser +from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor +from agentpress.processor.xml.xml_results_adder import XMLResultsAdder +from agentpress.processor.standard.standard_tool_parser import StandardToolParser +from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor +from agentpress.processor.standard.standard_results_adder import StandardResultsAdder class ThreadManager: """Manages conversation threads with LLM models and tool execution. @@ -33,204 +34,163 @@ class ThreadManager: Provides comprehensive conversation management, handling message threading, tool registration, and LLM interactions with support for both standard and XML-based tool execution patterns. - - Attributes: - threads_dir (str): Directory for storing thread files - tool_registry (ToolRegistry): Registry for managing available tools - - Methods: - add_tool: Register a tool with optional function filtering - create_thread: Create a new conversation thread - add_message: Add a message to a thread - list_messages: Retrieve messages from a thread - run_thread: Execute a conversation thread with LLM """ - def __init__(self, threads_dir: str = "threads"): - """Initialize ThreadManager. - - Args: - threads_dir: Directory to store thread files - - Notes: - Creates the threads directory if it doesn't exist - """ - self.threads_dir = threads_dir + def __init__(self): + """Initialize ThreadManager.""" + self.db = DBConnection() self.tool_registry = ToolRegistry() - os.makedirs(self.threads_dir, exist_ok=True) def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs): - """Add a tool to the ThreadManager. - - Args: - tool_class: The tool class to register - function_names: Optional list of specific functions to register - **kwargs: Additional arguments passed to tool initialization - - Notes: - - If function_names is None, all functions are registered - - Tool instances are created with provided kwargs - """ + """Add a tool to the ThreadManager.""" self.tool_registry.register_tool(tool_class, function_names, **kwargs) async def create_thread(self) -> str: - """Create a new conversation thread. - - Returns: - str: Unique thread ID for the created thread - - Raises: - IOError: If thread file creation fails - - Notes: - Creates a new thread file with an empty messages list - """ + """Create a new conversation thread.""" thread_id = str(uuid.uuid4()) - thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") - with open(thread_path, 'w') as f: - json.dump({"messages": []}, f) + await self.db.execute( + "INSERT INTO threads (thread_id, messages) VALUES (?, ?)", + (thread_id, json.dumps([])) + ) return thread_id async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None): - """Add a message to an existing thread. - - Args: - thread_id: ID of the target thread - message_data: Message content and metadata - images: Optional list of image data dictionaries - - Raises: - FileNotFoundError: If thread doesn't exist - Exception: For other operation failures - - Notes: - - Handles cleanup of incomplete tool calls - - Supports both text and image content - - Converts ToolResult instances to strings - """ + """Add a message to an existing thread.""" logging.info(f"Adding message to thread {thread_id} with images: {images}") - thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") try: - with open(thread_path, 'r') as f: - thread_data = json.load(f) - - messages = thread_data["messages"] - - # Handle cleanup of incomplete tool calls - if message_data['role'] == 'user': - last_assistant_index = next((i for i in reversed(range(len(messages))) - if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None) - - if last_assistant_index is not None: - tool_call_count = len(messages[last_assistant_index]['tool_calls']) - tool_response_count = sum(1 for msg in messages[last_assistant_index+1:] - if msg['role'] == 'tool') + async with self.db.transaction() as conn: + # Handle cleanup of incomplete tool calls + if message_data['role'] == 'user': + messages = await self.get_messages(thread_id) + last_assistant_index = next((i for i in reversed(range(len(messages))) + if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None) - if tool_call_count != tool_response_count: - await self.cleanup_incomplete_tool_calls(thread_id) + if last_assistant_index is not None: + tool_call_count = len(messages[last_assistant_index]['tool_calls']) + tool_response_count = sum(1 for msg in messages[last_assistant_index+1:] + if msg['role'] == 'tool') + + if tool_call_count != tool_response_count: + await self.cleanup_incomplete_tool_calls(thread_id) - # Convert ToolResult instances to strings - for key, value in message_data.items(): - if isinstance(value, ToolResult): - message_data[key] = str(value) + # Convert ToolResult instances to strings + for key, value in message_data.items(): + if isinstance(value, ToolResult): + message_data[key] = str(value) - # Handle image attachments - if images: - if isinstance(message_data['content'], str): - message_data['content'] = [{"type": "text", "text": message_data['content']}] - elif not isinstance(message_data['content'], list): - message_data['content'] = [] + # Handle image attachments + if images: + if isinstance(message_data['content'], str): + message_data['content'] = [{"type": "text", "text": message_data['content']}] + elif not isinstance(message_data['content'], list): + message_data['content'] = [] - for image in images: - image_content = { - "type": "image_url", - "image_url": { - "url": f"data:{image['content_type']};base64,{image['base64']}", - "detail": "high" + for image in images: + image_content = { + "type": "image_url", + "image_url": { + "url": f"data:{image['content_type']};base64,{image['base64']}", + "detail": "high" + } } - } - message_data['content'].append(image_content) + message_data['content'].append(image_content) - messages.append(message_data) - thread_data["messages"] = messages + # Get current messages + row = await self.db.fetch_one( + "SELECT messages FROM threads WHERE thread_id = ?", + (thread_id,) + ) + if not row: + raise ValueError(f"Thread {thread_id} not found") + + messages = json.loads(row[0]) + messages.append(message_data) + + # Update thread + await conn.execute( + """ + UPDATE threads + SET messages = ?, updated_at = CURRENT_TIMESTAMP + WHERE thread_id = ? + """, + (json.dumps(messages), thread_id) + ) + + logging.info(f"Message added to thread {thread_id}: {message_data}") - with open(thread_path, 'w') as f: - json.dump(thread_data, f) - - logging.info(f"Message added to thread {thread_id}: {message_data}") except Exception as e: logging.error(f"Failed to add message to thread {thread_id}: {e}") raise e - async def list_messages( + async def get_messages( self, - thread_id: str, - hide_tool_msgs: bool = False, - only_latest_assistant: bool = False, + thread_id: str, + hide_tool_msgs: bool = False, + only_latest_assistant: bool = False, regular_list: bool = True ) -> List[Dict[str, Any]]: - """Retrieve messages from a thread with optional filtering. - - Args: - thread_id: ID of the thread to retrieve messages from - hide_tool_msgs: If True, excludes tool messages and tool calls - only_latest_assistant: If True, returns only the most recent assistant message - regular_list: If True, only includes standard message types - - Returns: - List of messages matching the filter criteria - - Notes: - - Returns empty list if thread doesn't exist - - Filters can be combined for different views of the conversation - """ - thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") - - try: - with open(thread_path, 'r') as f: - thread_data = json.load(f) - messages = thread_data["messages"] - - if only_latest_assistant: - for msg in reversed(messages): - if msg.get('role') == 'assistant': - return [msg] - return [] - - filtered_messages = messages - - if hide_tool_msgs: - filtered_messages = [ - {k: v for k, v in msg.items() if k != 'tool_calls'} - for msg in filtered_messages - if msg.get('role') != 'tool' - ] - - if regular_list: - filtered_messages = [ - msg for msg in filtered_messages - if msg.get('role') in ['system', 'assistant', 'tool', 'user'] - ] - - return filtered_messages - except FileNotFoundError: + """Retrieve messages from a thread with optional filtering.""" + row = await self.db.fetch_one( + "SELECT messages FROM threads WHERE thread_id = ?", + (thread_id,) + ) + if not row: return [] + + messages = json.loads(row[0]) + + if only_latest_assistant: + for msg in reversed(messages): + if msg.get('role') == 'assistant': + return [msg] + return [] + + if hide_tool_msgs: + messages = [ + {k: v for k, v in msg.items() if k != 'tool_calls'} + for msg in messages + if msg.get('role') != 'tool' + ] + + if regular_list: + messages = [ + msg for msg in messages + if msg.get('role') in ['system', 'assistant', 'tool', 'user'] + ] + + return messages + + async def _update_message(self, thread_id: str, message: Dict[str, Any]): + """Update an existing message in the thread.""" + async with self.db.transaction() as conn: + row = await self.db.fetch_one( + "SELECT messages FROM threads WHERE thread_id = ?", + (thread_id,) + ) + if not row: + return + + messages = json.loads(row[0]) + + # Find and update the last assistant message + for i in reversed(range(len(messages))): + if messages[i].get('role') == 'assistant': + messages[i] = message + break + + await conn.execute( + """ + UPDATE threads + SET messages = ?, updated_at = CURRENT_TIMESTAMP + WHERE thread_id = ? + """, + (json.dumps(messages), thread_id) + ) async def cleanup_incomplete_tool_calls(self, thread_id: str): - """Clean up incomplete tool calls in a thread. - - Args: - thread_id: ID of the thread to clean up - - Returns: - bool: True if cleanup was performed, False otherwise - - Notes: - - Adds failure results for incomplete tool calls - - Maintains thread consistency after interruptions - """ - messages = await self.list_messages(thread_id) + """Clean up incomplete tool calls in a thread.""" + messages = await self.get_messages(thread_id) last_assistant_message = next((m for m in reversed(messages) if m['role'] == 'assistant' and 'tool_calls' in m), None) @@ -253,10 +213,15 @@ class ThreadManager: assistant_index = messages.index(last_assistant_message) messages[assistant_index+1:assistant_index+1] = failed_tool_results - thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") - with open(thread_path, 'w') as f: - json.dump({"messages": messages}, f) - + async with self.db.transaction() as conn: + await conn.execute( + """ + UPDATE threads + SET messages = ?, updated_at = CURRENT_TIMESTAMP + WHERE thread_id = ? + """, + (json.dumps(messages), thread_id) + ) return True return False @@ -326,7 +291,7 @@ class ThreadManager: results_adder = XMLResultsAdder(self) if xml_tool_calling else StandardResultsAdder(self) try: - messages = await self.list_messages(thread_id) + messages = await self.get_messages(thread_id) prepared_messages = [system_message] + messages if temporary_message: prepared_messages.append(temporary_message) @@ -345,9 +310,8 @@ class ThreadManager: available_functions=available_functions, add_message_callback=self.add_message, update_message_callback=self._update_message, - list_messages_callback=self.list_messages, + get_messages_callback=self.get_messages, parallel_tool_execution=parallel_tool_execution, - threads_dir=self.threads_dir, tool_parser=tool_parser, tool_executor=tool_executor, results_adder=results_adder @@ -405,25 +369,6 @@ class ThreadManager: stream=stream ) - async def _update_message(self, thread_id: str, message: Dict[str, Any]): - """Update an existing message in the thread.""" - thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") - try: - with open(thread_path, 'r') as f: - thread_data = json.load(f) - - # Find and update the last assistant message - for i in reversed(range(len(thread_data["messages"]))): - if thread_data["messages"][i]["role"] == "assistant": - thread_data["messages"][i] = message - break - - with open(thread_path, 'w') as f: - json.dump(thread_data, f) - except Exception as e: - logging.error(f"Error updating message in thread {thread_id}: {e}") - raise e - if __name__ == "__main__": import asyncio from agentpress.examples.example_agent.tools.files_tool import FilesTool @@ -503,7 +448,7 @@ if __name__ == "__main__": print("\n⨠Response completed\n") # Display final thread state - messages = await thread_manager.list_messages(thread_id) + messages = await thread_manager.get_messages(thread_id) print("\nš Final Thread State:") for msg in messages: role = msg.get('role', 'unknown') diff --git a/agentpress/thread_viewer_ui.py b/agentpress/thread_viewer_ui.py index b14f12c3..944cb499 100644 --- a/agentpress/thread_viewer_ui.py +++ b/agentpress/thread_viewer_ui.py @@ -1,21 +1,8 @@ import streamlit as st -import json -import os from datetime import datetime - -def load_thread_files(threads_dir: str): - """Load all thread files from the threads directory.""" - thread_files = [] - if os.path.exists(threads_dir): - for file in os.listdir(threads_dir): - if file.endswith('.json'): - thread_files.append(file) - return thread_files - -def load_thread_content(thread_file: str, threads_dir: str): - """Load the content of a specific thread file.""" - with open(os.path.join(threads_dir, thread_file), 'r') as f: - return json.load(f) +from agentpress.thread_manager import ThreadManager +from agentpress.db_connection import DBConnection +import asyncio def format_message_content(content): """Format message content handling both string and list formats.""" @@ -31,89 +18,123 @@ def format_message_content(content): return "\n".join(formatted_content) return str(content) +async def load_threads(): + """Load all thread IDs from the database.""" + db = DBConnection() + rows = await db.fetch_all("SELECT thread_id, created_at FROM threads ORDER BY created_at DESC") + return rows + +async def load_thread_content(thread_id: str): + """Load the content of a specific thread from the database.""" + thread_manager = ThreadManager() + return await thread_manager.get_messages(thread_id) + +def render_message(role, content, avatar): + """Render a message with a consistent chat-like style.""" + # Create columns for avatar and message + col1, col2 = st.columns([1, 11]) + + # Style based on role + if role == "assistant": + bgcolor = "rgba(25, 25, 25, 0.05)" + elif role == "user": + bgcolor = "rgba(25, 120, 180, 0.05)" + elif role == "system": + bgcolor = "rgba(180, 25, 25, 0.05)" + else: + bgcolor = "rgba(100, 100, 100, 0.05)" + + # Display avatar in first column + with col1: + st.markdown(f"