- Enhanced storage capabilities for production readiness:

- Added SQLite as primary storage backend for ThreadManager and StateManager
  - Implemented persistent storage with unique store IDs
  - Added CRUD operations for state management
  - Enabled multiple concurrent stores with referential integrity
  - Improved state persistence and retrieval mechanisms
This commit is contained in:
marko-kraemer 2024-11-19 03:46:25 +01:00
parent 13e98678e8
commit cb9f7b616f
20 changed files with 616 additions and 431 deletions

View File

@ -1,3 +1,11 @@
0.1.9
- Enhanced storage capabilities for production readiness:
- Added SQLite as primary storage backend for ThreadManager and StateManager
- Implemented persistent storage with unique store IDs
- Added CRUD operations for state management
- Enabled multiple concurrent stores with referential integrity
- Improved state persistence and retrieval mechanisms
0.1.8
- Added base processor classes for extensible tool handling:
- ToolParserBase: Abstract base class for parsing LLM responses

BIN
agentpress.db Normal file

Binary file not shown.

View File

@ -91,12 +91,13 @@ file contents here
async def run_agent(thread_id: str, use_xml: bool = True, max_iterations: int = 5):
"""Run the development agent with specified configuration."""
thread_manager = ThreadManager()
state_manager = StateManager()
thread_manager.add_tool(FilesTool)
thread_manager.add_tool(TerminalTool)
# Combine base message with XML format if needed
store_id = await StateManager.create_store()
state_manager = StateManager(store_id)
thread_manager.add_tool(FilesTool, store_id=store_id)
thread_manager.add_tool(TerminalTool, store_id=store_id)
system_message = {
"role": "system",
"content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "")
@ -199,6 +200,7 @@ def main():
async def async_main():
thread_manager = ThreadManager()
thread_id = await thread_manager.create_thread()
await thread_manager.add_message(
thread_id,

View File

@ -1,8 +1,9 @@
import os
import asyncio
from pathlib import Path
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.tools.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.state_manager import StateManager
from typing import Optional
class FilesTool(Tool):
"""File management tool for creating, updating, and deleting files.
@ -53,11 +54,11 @@ class FilesTool(Tool):
".sql"
}
def __init__(self):
def __init__(self, store_id: Optional[str] = None):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
self.state_manager = StateManager("state.json")
self.state_manager = StateManager(store_id)
self.SNIPPET_LINES = 4 # Number of context lines to show around edits
asyncio.create_task(self._init_workspace_state())

View File

@ -3,15 +3,16 @@ import asyncio
import subprocess
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.state_manager import StateManager
from typing import Optional
class TerminalTool(Tool):
"""Terminal command execution tool for workspace operations."""
def __init__(self):
def __init__(self, store_id: Optional[str] = None):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
self.state_manager = StateManager("state.json")
self.state_manager = StateManager(store_id)
async def _update_command_history(self, command: str, output: str, success: bool):
"""Update command history in state"""

View File

@ -26,14 +26,14 @@ MODULES = {
"processors": {
"required": True,
"files": [
"base_processors.py",
"llm_response_processor.py",
"standard_tool_parser.py",
"standard_tool_executor.py",
"standard_results_adder.py",
"xml_tool_parser.py",
"xml_tool_executor.py",
"xml_results_adder.py"
"processor/base_processors.py",
"processor/llm_response_processor.py",
"processor/standard/standard_tool_parser.py",
"processor/standard/standard_tool_executor.py",
"processor/standard/standard_results_adder.py",
"processor/xml/xml_tool_parser.py",
"processor/xml/xml_tool_executor.py",
"processor/xml/xml_results_adder.py"
],
"description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns."
},

125
agentpress/db_connection.py Normal file
View File

@ -0,0 +1,125 @@
"""
Centralized database connection management for AgentPress.
"""
import aiosqlite
import logging
from contextlib import asynccontextmanager
import os
import asyncio
class DBConnection:
    """Singleton manager for the shared AgentPress SQLite database.

    All state is class-level, so every ``DBConnection()`` returns the same
    instance. Schema creation is performed lazily on first use (idempotent
    and guarded by an asyncio lock), which means the singleton can safely be
    constructed from synchronous code with no event loop running.
    """

    _instance = None
    _initialized = False
    _db_path = os.path.join(os.getcwd(), "agentpress.db")
    _init_lock = asyncio.Lock()
    # Retained for backward compatibility with callers that inspect it;
    # initialization is now lazy, so no task is ever scheduled here.
    _initialization_task = None

    def __new__(cls):
        # NOTE: schema initialization is deliberately NOT kicked off here.
        # asyncio.create_task() requires a running event loop, so scheduling
        # it in __new__ made any synchronous construction of the singleton
        # raise RuntimeError. Initialization now happens on first connection().
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """No per-instance state; everything lives on the class."""
        pass

    @classmethod
    async def _initialize(cls):
        """Create the database schema once (idempotent, concurrency-safe)."""
        if cls._initialized:
            return
        async with cls._init_lock:
            if cls._initialized:  # Double-check after acquiring lock
                return
            try:
                async with aiosqlite.connect(cls._db_path) as db:
                    # Threads table: messages are stored as a JSON array blob.
                    await db.execute("""
                        CREATE TABLE IF NOT EXISTS threads (
                            thread_id TEXT PRIMARY KEY,
                            messages TEXT,
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        )
                    """)
                    # State stores table: store_data is a single JSON object.
                    await db.execute("""
                        CREATE TABLE IF NOT EXISTS state_stores (
                            store_id TEXT PRIMARY KEY,
                            store_data TEXT,
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        )
                    """)
                    await db.commit()
                    cls._initialized = True
                    logging.info("Database schema initialized")
            except Exception as e:
                logging.error(f"Database initialization error: {e}")
                raise

    @classmethod
    def set_db_path(cls, db_path: str):
        """Set custom database path.

        Raises:
            RuntimeError: if the schema has already been initialized, since
                moving the path afterwards would silently split the data.
        """
        if cls._initialized:
            raise RuntimeError("Cannot change database path after initialization")
        cls._db_path = db_path
        logging.info(f"Updated database path to: {db_path}")

    @asynccontextmanager
    async def connection(self):
        """Yield an open aiosqlite connection, initializing the schema first."""
        # Lazy init: no-op after the first successful call.
        await self._initialize()
        async with aiosqlite.connect(self._db_path) as conn:
            try:
                yield conn
            except Exception as e:
                logging.error(f"Database error: {e}")
                raise

    @asynccontextmanager
    async def transaction(self):
        """Execute operations in a transaction: commit on success, rollback on error."""
        async with self.connection() as db:
            try:
                yield db
                await db.commit()
            except Exception as e:
                await db.rollback()
                logging.error(f"Transaction error: {e}")
                raise

    async def execute(self, query: str, params: tuple = ()):
        """Execute a single statement and commit it immediately."""
        async with self.connection() as db:
            try:
                result = await db.execute(query, params)
                await db.commit()
                return result
            except Exception as e:
                logging.error(f"Query execution error: {e}")
                raise

    async def fetch_one(self, query: str, params: tuple = ()):
        """Return the first row of the query result, or None."""
        async with self.connection() as db:
            async with db.execute(query, params) as cursor:
                return await cursor.fetchone()

    async def fetch_all(self, query: str, params: tuple = ()):
        """Return all rows of the query result as a list."""
        async with self.connection() as db:
            async with db.execute(query, params) as cursor:
                return await cursor.fetchall()

View File

@ -172,7 +172,7 @@ class ResultsAdderBase(ABC):
Attributes:
add_message: Callback for adding new messages
update_message: Callback for updating existing messages
list_messages: Callback for retrieving thread messages
get_messages: Callback for retrieving thread messages
message_added: Flag tracking if initial message has been added
"""
@ -184,7 +184,7 @@ class ResultsAdderBase(ABC):
"""
self.add_message = thread_manager.add_message
self.update_message = thread_manager._update_message
self.list_messages = thread_manager.list_messages
self.get_messages = thread_manager.get_messages
self.message_added = False
@abstractmethod

View File

@ -11,10 +11,10 @@ This module provides comprehensive processing of LLM responses, including:
import asyncio
from typing import Callable, Dict, Any, AsyncGenerator, Optional
import logging
from agentpress.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.standard_tool_parser import StandardToolParser
from agentpress.standard_tool_executor import StandardToolExecutor
from agentpress.standard_results_adder import StandardResultsAdder
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
class LLMResponseProcessor:
"""Handles LLM response processing and tool execution management.
@ -40,9 +40,8 @@ class LLMResponseProcessor:
available_functions: Dict = None,
add_message_callback: Callable = None,
update_message_callback: Callable = None,
list_messages_callback: Callable = None,
get_messages_callback: Callable = None,
parallel_tool_execution: bool = True,
threads_dir: str = "threads",
tool_parser: Optional[ToolParserBase] = None,
tool_executor: Optional[ToolExecutorBase] = None,
results_adder: Optional[ResultsAdderBase] = None,
@ -55,9 +54,8 @@ class LLMResponseProcessor:
available_functions: Dictionary of available tool functions
add_message_callback: Callback for adding messages
update_message_callback: Callback for updating messages
list_messages_callback: Callback for listing messages
get_messages_callback: Callback for listing messages
parallel_tool_execution: Whether to execute tools in parallel
threads_dir: Directory for thread storage
tool_parser: Custom tool parser implementation
tool_executor: Custom tool executor implementation
results_adder: Custom results adder implementation
@ -67,16 +65,15 @@ class LLMResponseProcessor:
self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution)
self.tool_parser = tool_parser or StandardToolParser()
self.available_functions = available_functions or {}
self.threads_dir = threads_dir
# Create minimal thread manager if needed
if thread_manager is None and (add_message_callback and update_message_callback and list_messages_callback):
if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback):
class MinimalThreadManager:
def __init__(self, add_msg, update_msg, list_msg):
self.add_message = add_msg
self._update_message = update_msg
self.list_messages = list_msg
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, list_messages_callback)
self.get_messages = list_msg
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback)
self.results_adder = results_adder or StandardResultsAdder(thread_manager)

View File

@ -1,5 +1,5 @@
from typing import Dict, Any, List, Optional
from agentpress.base_processors import ResultsAdderBase
from agentpress.processor.base_processors import ResultsAdderBase
# --- Standard Results Adder Implementation ---
@ -81,6 +81,6 @@ class StandardResultsAdder(ResultsAdderBase):
- Checks for duplicate tool results before adding
- Adds result only if tool_call_id is unique
"""
messages = await self.list_messages(thread_id)
messages = await self.get_messages(thread_id)
if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages):
await self.add_message(thread_id, result)

View File

@ -9,7 +9,7 @@ import asyncio
import json
import logging
from typing import Dict, Any, List, Set, Callable, Optional
from agentpress.base_processors import ToolExecutorBase
from agentpress.processor.base_processors import ToolExecutorBase
from agentpress.tool import ToolResult
# --- Standard Tool Executor Implementation ---

View File

@ -1,6 +1,6 @@
import json
from typing import Dict, Any, Optional
from agentpress.base_processors import ToolParserBase
from agentpress.processor.base_processors import ToolParserBase
# --- Standard Tool Parser Implementation ---

View File

@ -1,6 +1,6 @@
import logging
from typing import Dict, Any, List, Optional
from agentpress.base_processors import ResultsAdderBase
from agentpress.processor.base_processors import ResultsAdderBase
class XMLResultsAdder(ResultsAdderBase):
"""XML-specific implementation for handling tool results and message processing.
@ -79,7 +79,7 @@ class XMLResultsAdder(ResultsAdderBase):
"""
try:
# Get the original tool call to find the root tag
messages = await self.list_messages(thread_id)
messages = await self.get_messages(thread_id)
assistant_msg = next((msg for msg in reversed(messages)
if msg['role'] == 'assistant'), None)

View File

@ -9,7 +9,7 @@ from typing import List, Dict, Any, Set, Callable, Optional
import asyncio
import json
import logging
from agentpress.base_processors import ToolExecutorBase
from agentpress.processor.base_processors import ToolExecutorBase
from agentpress.tool import ToolResult
from agentpress.tool_registry import ToolRegistry

View File

@ -7,7 +7,7 @@ complete and streaming responses with robust XML parsing and validation capabili
import logging
from typing import Dict, Any, Optional, List, Tuple
from agentpress.base_processors import ToolParserBase
from agentpress.processor.base_processors import ToolParserBase
import json
import re
from agentpress.tool_registry import ToolRegistry

View File

@ -1,76 +1,100 @@
import json
import os
import logging
from typing import Any
from typing import Any, Optional, List, Dict, Union, AsyncGenerator
from asyncio import Lock
from contextlib import asynccontextmanager
import uuid
from agentpress.db_connection import DBConnection
import asyncio
class StateManager:
"""
Manages persistent state storage for AgentPress components.
The StateManager provides thread-safe access to a JSON-based state store,
allowing components to save and retrieve data across sessions. It handles
concurrent access using asyncio locks and provides atomic operations for
state modifications.
The StateManager provides thread-safe access to a SQLite-based state store,
allowing components to save and retrieve data across sessions. Each store
has a unique ID and contains multiple key-value pairs in a single JSON object.
Attributes:
lock (Lock): Asyncio lock for thread-safe state access
store_file (str): Path to the JSON file storing the state
db (DBConnection): Database connection manager
store_id (str): Unique identifier for this state store
"""
def __init__(self, store_file: str = "state.json"):
def __init__(self, store_id: Optional[str] = None):
"""
Initialize StateManager with custom store file name.
Initialize StateManager with optional store ID.
Args:
store_file (str): Path to the JSON file to store state.
Defaults to "state.json" in the current directory.
store_id (str, optional): Unique identifier for the store. If None, creates new.
"""
self.lock = Lock()
self.store_file = store_file
logging.info(f"StateManager initialized with store file: {store_file}")
self.db = DBConnection()
self.store_id = store_id or str(uuid.uuid4())
logging.info(f"StateManager initialized with store_id: {self.store_id}")
asyncio.create_task(self._ensure_store_exists())
@classmethod
async def create_store(cls) -> str:
"""Create a new state store and return its unique ID.

Returns:
    str: UUID of the freshly inserted row in state_stores.
"""
store_id = str(uuid.uuid4())
# Bind a manager to the new ID, then insert the row eagerly so callers
# can rely on the store existing as soon as this coroutine returns.
manager = cls(store_id)
await manager._ensure_store_exists()
return store_id
async def _ensure_store_exists(self):
"""Insert this store's row if absent (idempotent).

INSERT OR IGNORE makes repeated calls safe: an existing row (and its
data) is left untouched, a missing one is created with an empty JSON
object as its initial store_data.
"""
async with self.db.transaction() as conn:
await conn.execute("""
INSERT OR IGNORE INTO state_stores (store_id, store_data)
VALUES (?, ?)
""", (self.store_id, json.dumps({})))
@asynccontextmanager
async def store_scope(self):
"""
Context manager for atomic state operations.
Provides thread-safe access to the state store, handling file I/O
and ensuring proper cleanup. Automatically loads the current state
and saves changes when the context exits.
Provides thread-safe access to the state store, handling database
operations and ensuring proper cleanup.
Yields:
dict: The current state store contents
Raises:
Exception: If there are errors reading from or writing to the store file
Exception: If there are errors with database operations
"""
try:
# Read current state
if os.path.exists(self.store_file):
with open(self.store_file, 'r') as f:
store = json.load(f)
else:
store = {}
yield store
# Write updated state
with open(self.store_file, 'w') as f:
json.dump(store, f, indent=2)
logging.debug("Store saved successfully")
except Exception as e:
logging.error("Error in store operation", exc_info=True)
raise
async with self.lock:
try:
async with self.db.transaction() as conn:
async with conn.execute(
"SELECT store_data FROM state_stores WHERE store_id = ?",
(self.store_id,)
) as cursor:
row = await cursor.fetchone()
store = json.loads(row[0]) if row else {}
yield store
await conn.execute(
"""
UPDATE state_stores
SET store_data = ?, updated_at = CURRENT_TIMESTAMP
WHERE store_id = ?
""",
(json.dumps(store), self.store_id)
)
except Exception as e:
logging.error("Error in store operation", exc_info=True)
raise
async def set(self, key: str, data: Any):
async def set(self, key: str, data: Any) -> Any:
"""
Store any JSON-serializable data with a simple key.
Store any JSON-serializable data with a key.
Args:
key (str): Simple string key like "config" or "settings"
data (Any): Any JSON-serializable data (dict, list, str, int, bool, etc)
data (Any): Any JSON-serializable data
Returns:
Any: The stored data
@ -78,17 +102,12 @@ class StateManager:
Raises:
Exception: If there are errors during storage operation
"""
async with self.lock:
async with self.store_scope() as store:
try:
store[key] = data # Will be JSON serialized when written to file
logging.info(f'Updated store key: {key}')
return data
except Exception as e:
logging.error(f'Error in set: {str(e)}')
raise
async with self.store_scope() as store:
store[key] = data
logging.info(f'Updated store key: {key}')
return data
async def get(self, key: str) -> Any:
async def get(self, key: str) -> Optional[Any]:
"""
Get data for a key.
@ -97,9 +116,6 @@ class StateManager:
Returns:
Any: The stored data for the key, or None if key not found
Note:
This operation is read-only and doesn't require locking
"""
async with self.store_scope() as store:
if key in store:
@ -115,17 +131,31 @@ class StateManager:
Args:
key (str): Simple string key like "config" or "settings"
Note:
No error is raised if the key doesn't exist
"""
async with self.lock:
async with self.store_scope() as store:
if key in store:
del store[key]
logging.info(f"Deleted key: {key}")
else:
logging.info(f"Key not found for deletion: {key}")
async with self.store_scope() as store:
if key in store:
del store[key]
logging.info(f"Deleted key: {key}")
async def update(self, key: str, data: Dict[str, Any]) -> Optional[Any]:
"""Shallow-merge *data* into the dict stored at *key*.

Args:
    key: Store key whose current value must already be a dict.
    data: Mapping whose entries overwrite/extend the stored dict.

Returns:
    The merged dict, or None if the key is missing or its value
    is not a dict (nothing is written in that case).
"""
async with self.store_scope() as store:
# Only merge into an existing dict; never create or replace the value.
if key in store and isinstance(store[key], dict):
store[key].update(data)
logging.info(f'Updated store key: {key}')
return store[key]
return None
async def append(self, key: str, item: Any) -> Optional[List[Any]]:
"""Append *item* to the list stored at *key*, creating the list if absent.

Args:
    key: Store key; initialized to an empty list when missing.
    item: Any JSON-serializable value to append.

Returns:
    The updated list, or None if *key* already holds a non-list value
    (in which case nothing is appended).
"""
async with self.store_scope() as store:
if key not in store:
store[key] = []
# Guard against clobbering a pre-existing non-list value under this key.
if isinstance(store[key], list):
store[key].append(item)
logging.info(f'Appended to key: {key}')
return store[key]
return None
async def export_store(self) -> dict:
"""
@ -133,9 +163,6 @@ class StateManager:
Returns:
dict: Complete contents of the state store
Note:
This operation is read-only and returns a copy of the store
"""
async with self.store_scope() as store:
logging.info(f"Store content: {store}")
@ -148,7 +175,29 @@ class StateManager:
Removes all data from the store, resetting it to an empty state.
This operation is atomic and thread-safe.
"""
async with self.lock:
async with self.store_scope() as store:
store.clear()
logging.info("Cleared store")
async with self.store_scope() as store:
store.clear()
logging.info("Cleared store")
@classmethod
async def list_stores(cls) -> List[Dict[str, Any]]:
"""
List all available state stores.
Returns:
List of dicts with keys "store_id", "created_at" and "updated_at",
ordered most-recently-updated first.
"""
# DBConnection is a singleton, so constructing one here is cheap and
# reuses the shared database path/schema.
db = DBConnection()
async with db.transaction() as conn:
async with conn.execute(
"SELECT store_id, created_at, updated_at FROM state_stores ORDER BY updated_at DESC"
) as cursor:
stores = [
{
"store_id": row[0],
"created_at": row[1],
"updated_at": row[2]
}
for row in await cursor.fetchall()
]
return stores

View File

@ -11,21 +11,22 @@ This module provides comprehensive conversation management, including:
import json
import logging
import os
import asyncio
import uuid
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator
from agentpress.llm import make_llm_api_call
from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
from agentpress.llm_response_processor import LLMResponseProcessor
from agentpress.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.processor.llm_response_processor import LLMResponseProcessor
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.db_connection import DBConnection
from agentpress.xml_tool_parser import XMLToolParser
from agentpress.xml_tool_executor import XMLToolExecutor
from agentpress.xml_results_adder import XMLResultsAdder
from agentpress.standard_tool_parser import StandardToolParser
from agentpress.standard_tool_executor import StandardToolExecutor
from agentpress.standard_results_adder import StandardResultsAdder
from agentpress.processor.xml.xml_tool_parser import XMLToolParser
from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor
from agentpress.processor.xml.xml_results_adder import XMLResultsAdder
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
class ThreadManager:
"""Manages conversation threads with LLM models and tool execution.
@ -33,204 +34,163 @@ class ThreadManager:
Provides comprehensive conversation management, handling message threading,
tool registration, and LLM interactions with support for both standard and
XML-based tool execution patterns.
Attributes:
threads_dir (str): Directory for storing thread files
tool_registry (ToolRegistry): Registry for managing available tools
Methods:
add_tool: Register a tool with optional function filtering
create_thread: Create a new conversation thread
add_message: Add a message to a thread
list_messages: Retrieve messages from a thread
run_thread: Execute a conversation thread with LLM
"""
def __init__(self, threads_dir: str = "threads"):
"""Initialize ThreadManager.
Args:
threads_dir: Directory to store thread files
Notes:
Creates the threads directory if it doesn't exist
"""
self.threads_dir = threads_dir
def __init__(self):
"""Initialize ThreadManager."""
self.db = DBConnection()
self.tool_registry = ToolRegistry()
os.makedirs(self.threads_dir, exist_ok=True)
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
"""Add a tool to the ThreadManager.
Args:
tool_class: The tool class to register
function_names: Optional list of specific functions to register
**kwargs: Additional arguments passed to tool initialization
Notes:
- If function_names is None, all functions are registered
- Tool instances are created with provided kwargs
"""
"""Add a tool to the ThreadManager."""
self.tool_registry.register_tool(tool_class, function_names, **kwargs)
async def create_thread(self) -> str:
"""Create a new conversation thread.
Returns:
str: Unique thread ID for the created thread
Raises:
IOError: If thread file creation fails
Notes:
Creates a new thread file with an empty messages list
"""
"""Create a new conversation thread."""
thread_id = str(uuid.uuid4())
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
with open(thread_path, 'w') as f:
json.dump({"messages": []}, f)
await self.db.execute(
"INSERT INTO threads (thread_id, messages) VALUES (?, ?)",
(thread_id, json.dumps([]))
)
return thread_id
async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
"""Add a message to an existing thread.
Args:
thread_id: ID of the target thread
message_data: Message content and metadata
images: Optional list of image data dictionaries
Raises:
FileNotFoundError: If thread doesn't exist
Exception: For other operation failures
Notes:
- Handles cleanup of incomplete tool calls
- Supports both text and image content
- Converts ToolResult instances to strings
"""
"""Add a message to an existing thread."""
logging.info(f"Adding message to thread {thread_id} with images: {images}")
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
try:
with open(thread_path, 'r') as f:
thread_data = json.load(f)
messages = thread_data["messages"]
# Handle cleanup of incomplete tool calls
if message_data['role'] == 'user':
last_assistant_index = next((i for i in reversed(range(len(messages)))
if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
if last_assistant_index is not None:
tool_call_count = len(messages[last_assistant_index]['tool_calls'])
tool_response_count = sum(1 for msg in messages[last_assistant_index+1:]
if msg['role'] == 'tool')
async with self.db.transaction() as conn:
# Handle cleanup of incomplete tool calls
if message_data['role'] == 'user':
messages = await self.get_messages(thread_id)
last_assistant_index = next((i for i in reversed(range(len(messages)))
if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
if tool_call_count != tool_response_count:
await self.cleanup_incomplete_tool_calls(thread_id)
if last_assistant_index is not None:
tool_call_count = len(messages[last_assistant_index]['tool_calls'])
tool_response_count = sum(1 for msg in messages[last_assistant_index+1:]
if msg['role'] == 'tool')
if tool_call_count != tool_response_count:
await self.cleanup_incomplete_tool_calls(thread_id)
# Convert ToolResult instances to strings
for key, value in message_data.items():
if isinstance(value, ToolResult):
message_data[key] = str(value)
# Convert ToolResult instances to strings
for key, value in message_data.items():
if isinstance(value, ToolResult):
message_data[key] = str(value)
# Handle image attachments
if images:
if isinstance(message_data['content'], str):
message_data['content'] = [{"type": "text", "text": message_data['content']}]
elif not isinstance(message_data['content'], list):
message_data['content'] = []
# Handle image attachments
if images:
if isinstance(message_data['content'], str):
message_data['content'] = [{"type": "text", "text": message_data['content']}]
elif not isinstance(message_data['content'], list):
message_data['content'] = []
for image in images:
image_content = {
"type": "image_url",
"image_url": {
"url": f"data:{image['content_type']};base64,{image['base64']}",
"detail": "high"
for image in images:
image_content = {
"type": "image_url",
"image_url": {
"url": f"data:{image['content_type']};base64,{image['base64']}",
"detail": "high"
}
}
}
message_data['content'].append(image_content)
message_data['content'].append(image_content)
messages.append(message_data)
thread_data["messages"] = messages
# Get current messages
row = await self.db.fetch_one(
"SELECT messages FROM threads WHERE thread_id = ?",
(thread_id,)
)
if not row:
raise ValueError(f"Thread {thread_id} not found")
messages = json.loads(row[0])
messages.append(message_data)
# Update thread
await conn.execute(
"""
UPDATE threads
SET messages = ?, updated_at = CURRENT_TIMESTAMP
WHERE thread_id = ?
""",
(json.dumps(messages), thread_id)
)
logging.info(f"Message added to thread {thread_id}: {message_data}")
with open(thread_path, 'w') as f:
json.dump(thread_data, f)
logging.info(f"Message added to thread {thread_id}: {message_data}")
except Exception as e:
logging.error(f"Failed to add message to thread {thread_id}: {e}")
raise e
async def list_messages(
async def get_messages(
self,
thread_id: str,
hide_tool_msgs: bool = False,
only_latest_assistant: bool = False,
thread_id: str,
hide_tool_msgs: bool = False,
only_latest_assistant: bool = False,
regular_list: bool = True
) -> List[Dict[str, Any]]:
"""Retrieve messages from a thread with optional filtering.
Args:
thread_id: ID of the thread to retrieve messages from
hide_tool_msgs: If True, excludes tool messages and tool calls
only_latest_assistant: If True, returns only the most recent assistant message
regular_list: If True, only includes standard message types
Returns:
List of messages matching the filter criteria
Notes:
- Returns empty list if thread doesn't exist
- Filters can be combined for different views of the conversation
"""
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
try:
with open(thread_path, 'r') as f:
thread_data = json.load(f)
messages = thread_data["messages"]
if only_latest_assistant:
for msg in reversed(messages):
if msg.get('role') == 'assistant':
return [msg]
return []
filtered_messages = messages
if hide_tool_msgs:
filtered_messages = [
{k: v for k, v in msg.items() if k != 'tool_calls'}
for msg in filtered_messages
if msg.get('role') != 'tool'
]
if regular_list:
filtered_messages = [
msg for msg in filtered_messages
if msg.get('role') in ['system', 'assistant', 'tool', 'user']
]
return filtered_messages
except FileNotFoundError:
"""Retrieve messages from a thread with optional filtering."""
row = await self.db.fetch_one(
"SELECT messages FROM threads WHERE thread_id = ?",
(thread_id,)
)
if not row:
return []
messages = json.loads(row[0])
if only_latest_assistant:
for msg in reversed(messages):
if msg.get('role') == 'assistant':
return [msg]
return []
if hide_tool_msgs:
messages = [
{k: v for k, v in msg.items() if k != 'tool_calls'}
for msg in messages
if msg.get('role') != 'tool'
]
if regular_list:
messages = [
msg for msg in messages
if msg.get('role') in ['system', 'assistant', 'tool', 'user']
]
return messages
async def _update_message(self, thread_id: str, message: Dict[str, Any]):
"""Replace the most recent assistant message in *thread_id* with *message*.

Silently returns if the thread does not exist; if no assistant message
is present, the messages list is rewritten unchanged.
"""
async with self.db.transaction() as conn:
# NOTE(review): fetch_one opens its own connection, so this read happens
# outside the `conn` transaction — a concurrent writer could interleave
# between read and update. Consider reading via `conn` instead; confirm.
row = await self.db.fetch_one(
"SELECT messages FROM threads WHERE thread_id = ?",
(thread_id,)
)
if not row:
return
messages = json.loads(row[0])
# Find and update the last assistant message
for i in reversed(range(len(messages))):
if messages[i].get('role') == 'assistant':
messages[i] = message
break
await conn.execute(
"""
UPDATE threads
SET messages = ?, updated_at = CURRENT_TIMESTAMP
WHERE thread_id = ?
""",
(json.dumps(messages), thread_id)
)
async def cleanup_incomplete_tool_calls(self, thread_id: str):
"""Clean up incomplete tool calls in a thread.
Args:
thread_id: ID of the thread to clean up
Returns:
bool: True if cleanup was performed, False otherwise
Notes:
- Adds failure results for incomplete tool calls
- Maintains thread consistency after interruptions
"""
messages = await self.list_messages(thread_id)
"""Clean up incomplete tool calls in a thread."""
messages = await self.get_messages(thread_id)
last_assistant_message = next((m for m in reversed(messages)
if m['role'] == 'assistant' and 'tool_calls' in m), None)
@ -253,10 +213,15 @@ class ThreadManager:
assistant_index = messages.index(last_assistant_message)
messages[assistant_index+1:assistant_index+1] = failed_tool_results
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
with open(thread_path, 'w') as f:
json.dump({"messages": messages}, f)
async with self.db.transaction() as conn:
await conn.execute(
"""
UPDATE threads
SET messages = ?, updated_at = CURRENT_TIMESTAMP
WHERE thread_id = ?
""",
(json.dumps(messages), thread_id)
)
return True
return False
@ -326,7 +291,7 @@ class ThreadManager:
results_adder = XMLResultsAdder(self) if xml_tool_calling else StandardResultsAdder(self)
try:
messages = await self.list_messages(thread_id)
messages = await self.get_messages(thread_id)
prepared_messages = [system_message] + messages
if temporary_message:
prepared_messages.append(temporary_message)
@ -345,9 +310,8 @@ class ThreadManager:
available_functions=available_functions,
add_message_callback=self.add_message,
update_message_callback=self._update_message,
list_messages_callback=self.list_messages,
get_messages_callback=self.get_messages,
parallel_tool_execution=parallel_tool_execution,
threads_dir=self.threads_dir,
tool_parser=tool_parser,
tool_executor=tool_executor,
results_adder=results_adder
@ -405,25 +369,6 @@ class ThreadManager:
stream=stream
)
async def _update_message(self, thread_id: str, message: Dict[str, Any]):
"""Update an existing message in the thread."""
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
try:
with open(thread_path, 'r') as f:
thread_data = json.load(f)
# Find and update the last assistant message
for i in reversed(range(len(thread_data["messages"]))):
if thread_data["messages"][i]["role"] == "assistant":
thread_data["messages"][i] = message
break
with open(thread_path, 'w') as f:
json.dump(thread_data, f)
except Exception as e:
logging.error(f"Error updating message in thread {thread_id}: {e}")
raise e
if __name__ == "__main__":
import asyncio
from agentpress.examples.example_agent.tools.files_tool import FilesTool
@ -503,7 +448,7 @@ if __name__ == "__main__":
print("\n✨ Response completed\n")
# Display final thread state
messages = await thread_manager.list_messages(thread_id)
messages = await thread_manager.get_messages(thread_id)
print("\n📝 Final Thread State:")
for msg in messages:
role = msg.get('role', 'unknown')

View File

@ -1,21 +1,8 @@
import streamlit as st
import json
import os
from datetime import datetime
def load_thread_files(threads_dir: str):
    """Return the names of all .json thread files in *threads_dir*.

    A missing directory yields an empty list rather than raising.
    """
    if not os.path.exists(threads_dir):
        return []
    return [name for name in os.listdir(threads_dir) if name.endswith('.json')]
def load_thread_content(thread_file: str, threads_dir: str):
    """Return the parsed JSON payload of one thread file."""
    thread_path = os.path.join(threads_dir, thread_file)
    with open(thread_path, 'r') as handle:
        raw = handle.read()
    return json.loads(raw)
from agentpress.thread_manager import ThreadManager
from agentpress.db_connection import DBConnection
import asyncio
def format_message_content(content):
"""Format message content handling both string and list formats."""
@ -31,89 +18,123 @@ def format_message_content(content):
return "\n".join(formatted_content)
return str(content)
async def load_threads():
    """Fetch every (thread_id, created_at) row, newest first."""
    connection = DBConnection()
    rows = await connection.fetch_all(
        "SELECT thread_id, created_at FROM threads ORDER BY created_at DESC"
    )
    return rows
async def load_thread_content(thread_id: str):
    """Fetch the message list for one thread via ThreadManager."""
    manager = ThreadManager()
    messages = await manager.get_messages(thread_id)
    return messages
def render_message(role, content, avatar):
    """Render one chat message as an avatar column plus a styled bubble."""
    # Narrow avatar gutter on the left, wide message body on the right.
    avatar_col, message_col = st.columns([1, 11])

    # Per-role background tints; unknown roles fall back to neutral grey.
    backgrounds = {
        "assistant": "rgba(25, 25, 25, 0.05)",
        "user": "rgba(25, 120, 180, 0.05)",
        "system": "rgba(180, 25, 25, 0.05)",
    }
    bgcolor = backgrounds.get(role, "rgba(100, 100, 100, 0.05)")

    with avatar_col:
        st.markdown(f"<div style='text-align: center; font-size: 24px;'>{avatar}</div>", unsafe_allow_html=True)

    with message_col:
        st.markdown(
            f"""
            <div style='background-color: {bgcolor}; padding: 10px; border-radius: 5px;'>
                <strong>{role.upper()}</strong><br>
                {content}
            </div>
            """,
            unsafe_allow_html=True
        )
def main():
st.title("Thread Viewer")
# Directory selection in sidebar
st.sidebar.title("Configuration")
# Initialize thread data in session state
if 'threads' not in st.session_state:
st.session_state.threads = asyncio.run(load_threads())
# Initialize session state with default directory
if 'threads_dir' not in st.session_state:
default_dir = "./threads"
if os.path.exists(default_dir):
st.session_state.threads_dir = default_dir
# Use Streamlit's file uploader for directory selection
uploaded_dir = st.sidebar.text_input(
"Enter threads directory path",
value="./threads" if not st.session_state.threads_dir else st.session_state.threads_dir,
placeholder="/path/to/threads",
help="Enter the full path to your threads directory"
# Thread selection in sidebar
st.sidebar.title("Select Thread")
if not st.session_state.threads:
st.warning("No threads found in database")
return
# Format thread options with creation date
thread_options = {
f"{row[0]} ({datetime.fromisoformat(row[1]).strftime('%Y-%m-%d %H:%M')})"
: row[0] for row in st.session_state.threads
}
selected_thread_display = st.sidebar.selectbox(
"Choose a thread",
options=list(thread_options.keys()),
)
# Automatically load directory if it exists
if os.path.exists(uploaded_dir):
st.session_state.threads_dir = uploaded_dir
else:
st.sidebar.error("Directory not found!")
if st.session_state.threads_dir:
st.sidebar.success(f"Selected directory: {st.session_state.threads_dir}")
threads_dir = st.session_state.threads_dir
if selected_thread_display:
# Get the actual thread ID from the display string
selected_thread_id = thread_options[selected_thread_display]
# Thread selection
st.sidebar.title("Select Thread")
thread_files = load_thread_files(threads_dir)
# Display thread ID in sidebar
st.sidebar.text(f"Thread ID: {selected_thread_id}")
if not thread_files:
st.warning(f"No thread files found in '{threads_dir}'")
return
# Add refresh button
if st.sidebar.button("🔄 Refresh Thread"):
st.session_state.threads = asyncio.run(load_threads())
st.experimental_rerun()
selected_thread = st.sidebar.selectbox(
"Choose a thread file",
thread_files,
format_func=lambda x: f"Thread: {x.replace('.json', '')}"
)
# Load and display messages
messages = asyncio.run(load_thread_content(selected_thread_id))
if selected_thread:
thread_data = load_thread_content(selected_thread, threads_dir)
messages = thread_data.get("messages", [])
# Display messages in chat-like interface
for message in messages:
role = message.get("role", "unknown")
content = message.get("content", "")
# Display thread ID in sidebar
st.sidebar.text(f"Thread ID: {selected_thread.replace('.json', '')}")
# Determine avatar based on role
if role == "assistant":
avatar = "🤖"
elif role == "user":
avatar = "👤"
elif role == "system":
avatar = "⚙️"
elif role == "tool":
avatar = "🔧"
else:
avatar = ""
# Display messages in chat-like interface
for message in messages:
role = message.get("role", "unknown")
content = message.get("content", "")
# Determine avatar based on role
if role == "assistant":
avatar = "🤖"
elif role == "user":
avatar = "👤"
elif role == "system":
avatar = "⚙️"
elif role == "tool":
avatar = "🔧"
else:
avatar = ""
# Format the message container
with st.chat_message(role, avatar=avatar):
formatted_content = format_message_content(content)
st.markdown(formatted_content)
if "tool_calls" in message:
st.markdown("**Tool Calls:**")
for tool_call in message["tool_calls"]:
st.code(
f"Function: {tool_call['function']['name']}\n"
f"Arguments: {tool_call['function']['arguments']}",
language="json"
)
else:
st.sidebar.warning("Please enter and load a threads directory")
# Format the content
formatted_content = format_message_content(content)
# Render the message
render_message(role, formatted_content, avatar)
# Display tool calls if present
if "tool_calls" in message:
with st.expander("🛠️ Tool Calls"):
for tool_call in message["tool_calls"]:
st.code(
f"Function: {tool_call['function']['name']}\n"
f"Arguments: {tool_call['function']['arguments']}",
language="json"
)
# Add some spacing between messages
st.markdown("<br>", unsafe_allow_html=True)
if __name__ == "__main__":
main()

103
poetry.lock generated
View File

@ -160,27 +160,44 @@ files = [
frozenlist = ">=1.1.0"
[[package]]
name = "altair"
version = "5.4.1"
description = "Vega-Altair: A declarative statistical visualization library for Python."
name = "aiosqlite"
version = "0.20.0"
description = "asyncio bridge to the standard sqlite3 module"
optional = false
python-versions = ">=3.8"
files = [
{file = "altair-5.4.1-py3-none-any.whl", hash = "sha256:0fb130b8297a569d08991fb6fe763582e7569f8a04643bbd9212436e3be04aef"},
{file = "altair-5.4.1.tar.gz", hash = "sha256:0ce8c2e66546cb327e5f2d7572ec0e7c6feece816203215613962f0ec1d76a82"},
{file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"},
{file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"},
]
[package.dependencies]
jinja2 = "*"
jsonschema = ">=3.0"
narwhals = ">=1.5.2"
packaging = "*"
typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""}
typing_extensions = ">=4.0"
[package.extras]
all = ["altair-tiles (>=0.3.0)", "anywidget (>=0.9.0)", "numpy", "pandas (>=0.25.3)", "pyarrow (>=11)", "vega-datasets (>=0.9.0)", "vegafusion[embed] (>=1.6.6)", "vl-convert-python (>=1.6.0)"]
dev = ["geopandas", "hatch", "ibis-framework[polars]", "ipython[kernel]", "mistune", "mypy", "pandas (>=0.25.3)", "pandas-stubs", "polars (>=0.20.3)", "pytest", "pytest-cov", "pytest-xdist[psutil] (>=3.5,<4.0)", "ruff (>=0.6.0)", "types-jsonschema", "types-setuptools"]
doc = ["docutils", "jinja2", "myst-parser", "numpydoc", "pillow (>=9,<10)", "pydata-sphinx-theme (>=0.14.1)", "scipy", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinxext-altair"]
dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"]
docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"]
[[package]]
name = "altair"
version = "4.2.2"
description = "Altair: A declarative statistical visualization library for Python."
optional = false
python-versions = ">=3.7"
files = [
{file = "altair-4.2.2-py3-none-any.whl", hash = "sha256:8b45ebeaf8557f2d760c5c77b79f02ae12aee7c46c27c06014febab6f849bc87"},
{file = "altair-4.2.2.tar.gz", hash = "sha256:39399a267c49b30d102c10411e67ab26374156a84b1aeb9fcd15140429ba49c5"},
]
[package.dependencies]
entrypoints = "*"
jinja2 = "*"
jsonschema = ">=3.0"
numpy = "*"
pandas = ">=0.18"
toolz = "*"
[package.extras]
dev = ["black", "docutils", "flake8", "ipython", "m2r", "mistune (<2.0.0)", "pytest", "recommonmark", "sphinx", "vega-datasets"]
[[package]]
name = "annotated-types"
@ -226,6 +243,19 @@ files = [
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
]
[[package]]
name = "asyncio"
version = "3.4.3"
description = "reference implementation of PEP 3156"
optional = false
python-versions = "*"
files = [
{file = "asyncio-3.4.3-cp33-none-win32.whl", hash = "sha256:b62c9157d36187eca799c378e572c969f0da87cd5fc42ca372d92cdb06e7e1de"},
{file = "asyncio-3.4.3-cp33-none-win_amd64.whl", hash = "sha256:c46a87b48213d7464f22d9a497b9eef8c1928b68320a2fa94240f969f6fec08c"},
{file = "asyncio-3.4.3-py3-none-any.whl", hash = "sha256:c4d18b22701821de07bd6aea8b53d21449ec0ec5680645e5317062ea21817d2d"},
{file = "asyncio-3.4.3.tar.gz", hash = "sha256:83360ff8bc97980e4ff25c964c7bd3923d333d177aa4f7fb736b019f26c7cb41"},
]
[[package]]
name = "attrs"
version = "24.2.0"
@ -428,6 +458,17 @@ files = [
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]
[[package]]
name = "entrypoints"
version = "0.4"
description = "Discover and load entry points from installed packages."
optional = false
python-versions = ">=3.6"
files = [
{file = "entrypoints-0.4-py3-none-any.whl", hash = "sha256:f174b5ff827504fd3cd97cc3f8649f3693f51538c7e4bdf3ef002c8429d42f9f"},
{file = "entrypoints-0.4.tar.gz", hash = "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4"},
]
[[package]]
name = "exceptiongroup"
version = "1.2.2"
@ -1140,25 +1181,6 @@ files = [
[package.dependencies]
typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}
[[package]]
name = "narwhals"
version = "1.12.1"
description = "Extremely lightweight compatibility layer between dataframe libraries"
optional = false
python-versions = ">=3.8"
files = [
{file = "narwhals-1.12.1-py3-none-any.whl", hash = "sha256:e251cb5fe4cabdcabb847d359f5de2b81df773df47e46f858fd5570c936919c4"},
{file = "narwhals-1.12.1.tar.gz", hash = "sha256:65ff0d1e8b509df8b52b395e8d5fe96751a68657bdabf0f3057a970ec2cd1809"},
]
[package.extras]
cudf = ["cudf (>=23.08.00)"]
dask = ["dask[dataframe] (>=2024.7)"]
modin = ["modin"]
pandas = ["pandas (>=0.25.3)"]
polars = ["polars (>=0.20.3)"]
pyarrow = ["pyarrow (>=11.0.0)"]
[[package]]
name = "numpy"
version = "2.0.2"
@ -1301,9 +1323,9 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -1690,8 +1712,8 @@ files = [
annotated-types = ">=0.6.0"
pydantic-core = "2.23.4"
typing-extensions = [
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
{version = ">=4.6.1", markers = "python_version < \"3.13\""},
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
]
[package.extras]
@ -2612,6 +2634,17 @@ files = [
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
]
[[package]]
name = "toolz"
version = "1.0.0"
description = "List processing tools and functional utilities"
optional = false
python-versions = ">=3.8"
files = [
{file = "toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236"},
{file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"},
]
[[package]]
name = "tornado"
version = "6.4.1"
@ -2893,4 +2926,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "d4c229cc0ecd64741dcf73186dce1c90100230e57acca3aa589e66dbb3052e9d"
content-hash = "32b3aefcb3a32a251cd1de7abaa721729c4e2ce2640625a80fcf51a0e3d5da2f"

View File

@ -30,6 +30,9 @@ packaging = "^23.2"
setuptools = "^75.3.0"
pytest = "^8.3.3"
pytest-asyncio = "^0.24.0"
aiosqlite = "^0.20.0"
asyncio = "^3.4.3"  # NOTE(review): PyPI 'asyncio' is a stale Python-3.3-era backport of the stdlib module; on python ^3.9 this dependency is unnecessary and should be removed
altair = "4.2.2"
[tool.poetry.scripts]
agentpress = "agentpress.cli:main"