mirror of https://github.com/kortix-ai/suna.git
- Enhanced storage capabilities for production readiness:
- Added SQLite as primary storage backend for ThreadManager and StateManager - Implemented persistent storage with unique store IDs - Added CRUD operations for state management - Enabled multiple concurrent stores with referential integrity - Improved state persistence and retrieval mechanisms
This commit is contained in:
parent
13e98678e8
commit
cb9f7b616f
|
@ -1,3 +1,11 @@
|
|||
0.1.9
|
||||
- Enhanced storage capabilities for production readiness:
|
||||
- Added SQLite as primary storage backend for ThreadManager and StateManager
|
||||
- Implemented persistent storage with unique store IDs
|
||||
- Added CRUD operations for state management
|
||||
- Enabled multiple concurrent stores with referential integrity
|
||||
- Improved state persistence and retrieval mechanisms
|
||||
|
||||
0.1.8
|
||||
- Added base processor classes for extensible tool handling:
|
||||
- ToolParserBase: Abstract base class for parsing LLM responses
|
||||
|
|
Binary file not shown.
|
@ -91,12 +91,13 @@ file contents here
|
|||
async def run_agent(thread_id: str, use_xml: bool = True, max_iterations: int = 5):
|
||||
"""Run the development agent with specified configuration."""
|
||||
thread_manager = ThreadManager()
|
||||
state_manager = StateManager()
|
||||
|
||||
thread_manager.add_tool(FilesTool)
|
||||
thread_manager.add_tool(TerminalTool)
|
||||
|
||||
# Combine base message with XML format if needed
|
||||
store_id = await StateManager.create_store()
|
||||
state_manager = StateManager(store_id)
|
||||
|
||||
thread_manager.add_tool(FilesTool, store_id=store_id)
|
||||
thread_manager.add_tool(TerminalTool, store_id=store_id)
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "")
|
||||
|
@ -199,6 +200,7 @@ def main():
|
|||
|
||||
async def async_main():
|
||||
thread_manager = ThreadManager()
|
||||
|
||||
thread_id = await thread_manager.create_thread()
|
||||
await thread_manager.add_message(
|
||||
thread_id,
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
||||
from agentpress.tools.tool import Tool, ToolResult, openapi_schema, xml_schema
|
||||
from agentpress.state_manager import StateManager
|
||||
from typing import Optional
|
||||
|
||||
class FilesTool(Tool):
|
||||
"""File management tool for creating, updating, and deleting files.
|
||||
|
@ -53,11 +54,11 @@ class FilesTool(Tool):
|
|||
".sql"
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, store_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
self.state_manager = StateManager("state.json")
|
||||
self.state_manager = StateManager(store_id)
|
||||
self.SNIPPET_LINES = 4 # Number of context lines to show around edits
|
||||
asyncio.create_task(self._init_workspace_state())
|
||||
|
||||
|
|
|
@ -3,15 +3,16 @@ import asyncio
|
|||
import subprocess
|
||||
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
||||
from agentpress.state_manager import StateManager
|
||||
from typing import Optional
|
||||
|
||||
class TerminalTool(Tool):
|
||||
"""Terminal command execution tool for workspace operations."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, store_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
self.state_manager = StateManager("state.json")
|
||||
self.state_manager = StateManager(store_id)
|
||||
|
||||
async def _update_command_history(self, command: str, output: str, success: bool):
|
||||
"""Update command history in state"""
|
||||
|
|
|
@ -26,14 +26,14 @@ MODULES = {
|
|||
"processors": {
|
||||
"required": True,
|
||||
"files": [
|
||||
"base_processors.py",
|
||||
"llm_response_processor.py",
|
||||
"standard_tool_parser.py",
|
||||
"standard_tool_executor.py",
|
||||
"standard_results_adder.py",
|
||||
"xml_tool_parser.py",
|
||||
"xml_tool_executor.py",
|
||||
"xml_results_adder.py"
|
||||
"processor/base_processors.py",
|
||||
"processor/llm_response_processor.py",
|
||||
"processor/standard/standard_tool_parser.py",
|
||||
"processor/standard/standard_tool_executor.py",
|
||||
"processor/standard/standard_results_adder.py",
|
||||
"processor/xml/xml_tool_parser.py",
|
||||
"processor/xml/xml_tool_executor.py",
|
||||
"processor/xml/xml_results_adder.py"
|
||||
],
|
||||
"description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns."
|
||||
},
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
"""
|
||||
Centralized database connection management for AgentPress.
|
||||
"""
|
||||
|
||||
import aiosqlite
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
class DBConnection:
|
||||
"""Singleton database connection manager."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
_db_path = os.path.join(os.getcwd(), "agentpress.db")
|
||||
_init_lock = asyncio.Lock()
|
||||
_initialization_task = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
# Start initialization when instance is first created
|
||||
cls._initialization_task = asyncio.create_task(cls._instance._initialize())
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""No initialization needed in __init__ as it's handled in __new__"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _initialize(cls):
|
||||
"""Internal initialization method."""
|
||||
if cls._initialized:
|
||||
return
|
||||
|
||||
async with cls._init_lock:
|
||||
if cls._initialized: # Double-check after acquiring lock
|
||||
return
|
||||
|
||||
try:
|
||||
async with aiosqlite.connect(cls._db_path) as db:
|
||||
# Threads table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
thread_id TEXT PRIMARY KEY,
|
||||
messages TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# State stores table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS state_stores (
|
||||
store_id TEXT PRIMARY KEY,
|
||||
store_data TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
await db.commit()
|
||||
cls._initialized = True
|
||||
logging.info("Database schema initialized")
|
||||
except Exception as e:
|
||||
logging.error(f"Database initialization error: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def set_db_path(cls, db_path: str):
|
||||
"""Set custom database path."""
|
||||
if cls._initialized:
|
||||
raise RuntimeError("Cannot change database path after initialization")
|
||||
cls._db_path = db_path
|
||||
logging.info(f"Updated database path to: {db_path}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self):
|
||||
"""Get a database connection."""
|
||||
# Wait for initialization to complete if it hasn't already
|
||||
if self._initialization_task and not self._initialized:
|
||||
await self._initialization_task
|
||||
|
||||
async with aiosqlite.connect(self._db_path) as conn:
|
||||
try:
|
||||
yield conn
|
||||
except Exception as e:
|
||||
logging.error(f"Database error: {e}")
|
||||
raise
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""Execute operations in a transaction."""
|
||||
async with self.connection() as db:
|
||||
try:
|
||||
yield db
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logging.error(f"Transaction error: {e}")
|
||||
raise
|
||||
|
||||
async def execute(self, query: str, params: tuple = ()):
|
||||
"""Execute a single query."""
|
||||
async with self.connection() as db:
|
||||
try:
|
||||
result = await db.execute(query, params)
|
||||
await db.commit()
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Query execution error: {e}")
|
||||
raise
|
||||
|
||||
async def fetch_one(self, query: str, params: tuple = ()):
|
||||
"""Fetch a single row."""
|
||||
async with self.connection() as db:
|
||||
async with db.execute(query, params) as cursor:
|
||||
return await cursor.fetchone()
|
||||
|
||||
async def fetch_all(self, query: str, params: tuple = ()):
|
||||
"""Fetch all rows."""
|
||||
async with self.connection() as db:
|
||||
async with db.execute(query, params) as cursor:
|
||||
return await cursor.fetchall()
|
|
@ -172,7 +172,7 @@ class ResultsAdderBase(ABC):
|
|||
Attributes:
|
||||
add_message: Callback for adding new messages
|
||||
update_message: Callback for updating existing messages
|
||||
list_messages: Callback for retrieving thread messages
|
||||
get_messages: Callback for retrieving thread messages
|
||||
message_added: Flag tracking if initial message has been added
|
||||
"""
|
||||
|
||||
|
@ -184,7 +184,7 @@ class ResultsAdderBase(ABC):
|
|||
"""
|
||||
self.add_message = thread_manager.add_message
|
||||
self.update_message = thread_manager._update_message
|
||||
self.list_messages = thread_manager.list_messages
|
||||
self.get_messages = thread_manager.get_messages
|
||||
self.message_added = False
|
||||
|
||||
@abstractmethod
|
|
@ -11,10 +11,10 @@ This module provides comprehensive processing of LLM responses, including:
|
|||
import asyncio
|
||||
from typing import Callable, Dict, Any, AsyncGenerator, Optional
|
||||
import logging
|
||||
from agentpress.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
|
||||
from agentpress.standard_tool_parser import StandardToolParser
|
||||
from agentpress.standard_tool_executor import StandardToolExecutor
|
||||
from agentpress.standard_results_adder import StandardResultsAdder
|
||||
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
|
||||
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
|
||||
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
|
||||
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
|
||||
|
||||
class LLMResponseProcessor:
|
||||
"""Handles LLM response processing and tool execution management.
|
||||
|
@ -40,9 +40,8 @@ class LLMResponseProcessor:
|
|||
available_functions: Dict = None,
|
||||
add_message_callback: Callable = None,
|
||||
update_message_callback: Callable = None,
|
||||
list_messages_callback: Callable = None,
|
||||
get_messages_callback: Callable = None,
|
||||
parallel_tool_execution: bool = True,
|
||||
threads_dir: str = "threads",
|
||||
tool_parser: Optional[ToolParserBase] = None,
|
||||
tool_executor: Optional[ToolExecutorBase] = None,
|
||||
results_adder: Optional[ResultsAdderBase] = None,
|
||||
|
@ -55,9 +54,8 @@ class LLMResponseProcessor:
|
|||
available_functions: Dictionary of available tool functions
|
||||
add_message_callback: Callback for adding messages
|
||||
update_message_callback: Callback for updating messages
|
||||
list_messages_callback: Callback for listing messages
|
||||
get_messages_callback: Callback for listing messages
|
||||
parallel_tool_execution: Whether to execute tools in parallel
|
||||
threads_dir: Directory for thread storage
|
||||
tool_parser: Custom tool parser implementation
|
||||
tool_executor: Custom tool executor implementation
|
||||
results_adder: Custom results adder implementation
|
||||
|
@ -67,16 +65,15 @@ class LLMResponseProcessor:
|
|||
self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution)
|
||||
self.tool_parser = tool_parser or StandardToolParser()
|
||||
self.available_functions = available_functions or {}
|
||||
self.threads_dir = threads_dir
|
||||
|
||||
# Create minimal thread manager if needed
|
||||
if thread_manager is None and (add_message_callback and update_message_callback and list_messages_callback):
|
||||
if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback):
|
||||
class MinimalThreadManager:
|
||||
def __init__(self, add_msg, update_msg, list_msg):
|
||||
self.add_message = add_msg
|
||||
self._update_message = update_msg
|
||||
self.list_messages = list_msg
|
||||
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, list_messages_callback)
|
||||
self.get_messages = list_msg
|
||||
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback)
|
||||
|
||||
self.results_adder = results_adder or StandardResultsAdder(thread_manager)
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Dict, Any, List, Optional
|
||||
from agentpress.base_processors import ResultsAdderBase
|
||||
from agentpress.processor.base_processors import ResultsAdderBase
|
||||
|
||||
# --- Standard Results Adder Implementation ---
|
||||
|
||||
|
@ -81,6 +81,6 @@ class StandardResultsAdder(ResultsAdderBase):
|
|||
- Checks for duplicate tool results before adding
|
||||
- Adds result only if tool_call_id is unique
|
||||
"""
|
||||
messages = await self.list_messages(thread_id)
|
||||
messages = await self.get_messages(thread_id)
|
||||
if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages):
|
||||
await self.add_message(thread_id, result)
|
|
@ -9,7 +9,7 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, List, Set, Callable, Optional
|
||||
from agentpress.base_processors import ToolExecutorBase
|
||||
from agentpress.processor.base_processors import ToolExecutorBase
|
||||
from agentpress.tool import ToolResult
|
||||
|
||||
# --- Standard Tool Executor Implementation ---
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from agentpress.base_processors import ToolParserBase
|
||||
from agentpress.processor.base_processors import ToolParserBase
|
||||
|
||||
# --- Standard Tool Parser Implementation ---
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from agentpress.base_processors import ResultsAdderBase
|
||||
from agentpress.processor.base_processors import ResultsAdderBase
|
||||
|
||||
class XMLResultsAdder(ResultsAdderBase):
|
||||
"""XML-specific implementation for handling tool results and message processing.
|
||||
|
@ -79,7 +79,7 @@ class XMLResultsAdder(ResultsAdderBase):
|
|||
"""
|
||||
try:
|
||||
# Get the original tool call to find the root tag
|
||||
messages = await self.list_messages(thread_id)
|
||||
messages = await self.get_messages(thread_id)
|
||||
assistant_msg = next((msg for msg in reversed(messages)
|
||||
if msg['role'] == 'assistant'), None)
|
||||
|
|
@ -9,7 +9,7 @@ from typing import List, Dict, Any, Set, Callable, Optional
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from agentpress.base_processors import ToolExecutorBase
|
||||
from agentpress.processor.base_processors import ToolExecutorBase
|
||||
from agentpress.tool import ToolResult
|
||||
from agentpress.tool_registry import ToolRegistry
|
||||
|
|
@ -7,7 +7,7 @@ complete and streaming responses with robust XML parsing and validation capabili
|
|||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from agentpress.base_processors import ToolParserBase
|
||||
from agentpress.processor.base_processors import ToolParserBase
|
||||
import json
|
||||
import re
|
||||
from agentpress.tool_registry import ToolRegistry
|
|
@ -1,76 +1,100 @@
|
|||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Optional, List, Dict, Union, AsyncGenerator
|
||||
from asyncio import Lock
|
||||
from contextlib import asynccontextmanager
|
||||
import uuid
|
||||
from agentpress.db_connection import DBConnection
|
||||
import asyncio
|
||||
|
||||
class StateManager:
|
||||
"""
|
||||
Manages persistent state storage for AgentPress components.
|
||||
|
||||
The StateManager provides thread-safe access to a JSON-based state store,
|
||||
allowing components to save and retrieve data across sessions. It handles
|
||||
concurrent access using asyncio locks and provides atomic operations for
|
||||
state modifications.
|
||||
The StateManager provides thread-safe access to a SQLite-based state store,
|
||||
allowing components to save and retrieve data across sessions. Each store
|
||||
has a unique ID and contains multiple key-value pairs in a single JSON object.
|
||||
|
||||
Attributes:
|
||||
lock (Lock): Asyncio lock for thread-safe state access
|
||||
store_file (str): Path to the JSON file storing the state
|
||||
db (DBConnection): Database connection manager
|
||||
store_id (str): Unique identifier for this state store
|
||||
"""
|
||||
|
||||
def __init__(self, store_file: str = "state.json"):
|
||||
def __init__(self, store_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize StateManager with custom store file name.
|
||||
Initialize StateManager with optional store ID.
|
||||
|
||||
Args:
|
||||
store_file (str): Path to the JSON file to store state.
|
||||
Defaults to "state.json" in the current directory.
|
||||
store_id (str, optional): Unique identifier for the store. If None, creates new.
|
||||
"""
|
||||
self.lock = Lock()
|
||||
self.store_file = store_file
|
||||
logging.info(f"StateManager initialized with store file: {store_file}")
|
||||
self.db = DBConnection()
|
||||
self.store_id = store_id or str(uuid.uuid4())
|
||||
logging.info(f"StateManager initialized with store_id: {self.store_id}")
|
||||
asyncio.create_task(self._ensure_store_exists())
|
||||
|
||||
@classmethod
|
||||
async def create_store(cls) -> str:
|
||||
"""Create a new state store and return its ID."""
|
||||
store_id = str(uuid.uuid4())
|
||||
manager = cls(store_id)
|
||||
await manager._ensure_store_exists()
|
||||
return store_id
|
||||
|
||||
async def _ensure_store_exists(self):
|
||||
"""Ensure store exists in database."""
|
||||
async with self.db.transaction() as conn:
|
||||
await conn.execute("""
|
||||
INSERT OR IGNORE INTO state_stores (store_id, store_data)
|
||||
VALUES (?, ?)
|
||||
""", (self.store_id, json.dumps({})))
|
||||
|
||||
@asynccontextmanager
|
||||
async def store_scope(self):
|
||||
"""
|
||||
Context manager for atomic state operations.
|
||||
|
||||
Provides thread-safe access to the state store, handling file I/O
|
||||
and ensuring proper cleanup. Automatically loads the current state
|
||||
and saves changes when the context exits.
|
||||
Provides thread-safe access to the state store, handling database
|
||||
operations and ensuring proper cleanup.
|
||||
|
||||
Yields:
|
||||
dict: The current state store contents
|
||||
|
||||
Raises:
|
||||
Exception: If there are errors reading from or writing to the store file
|
||||
Exception: If there are errors with database operations
|
||||
"""
|
||||
try:
|
||||
# Read current state
|
||||
if os.path.exists(self.store_file):
|
||||
with open(self.store_file, 'r') as f:
|
||||
store = json.load(f)
|
||||
else:
|
||||
store = {}
|
||||
|
||||
yield store
|
||||
|
||||
# Write updated state
|
||||
with open(self.store_file, 'w') as f:
|
||||
json.dump(store, f, indent=2)
|
||||
logging.debug("Store saved successfully")
|
||||
except Exception as e:
|
||||
logging.error("Error in store operation", exc_info=True)
|
||||
raise
|
||||
async with self.lock:
|
||||
try:
|
||||
async with self.db.transaction() as conn:
|
||||
async with conn.execute(
|
||||
"SELECT store_data FROM state_stores WHERE store_id = ?",
|
||||
(self.store_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
store = json.loads(row[0]) if row else {}
|
||||
|
||||
yield store
|
||||
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE state_stores
|
||||
SET store_data = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE store_id = ?
|
||||
""",
|
||||
(json.dumps(store), self.store_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error in store operation", exc_info=True)
|
||||
raise
|
||||
|
||||
async def set(self, key: str, data: Any):
|
||||
async def set(self, key: str, data: Any) -> Any:
|
||||
"""
|
||||
Store any JSON-serializable data with a simple key.
|
||||
Store any JSON-serializable data with a key.
|
||||
|
||||
Args:
|
||||
key (str): Simple string key like "config" or "settings"
|
||||
data (Any): Any JSON-serializable data (dict, list, str, int, bool, etc)
|
||||
data (Any): Any JSON-serializable data
|
||||
|
||||
Returns:
|
||||
Any: The stored data
|
||||
|
@ -78,17 +102,12 @@ class StateManager:
|
|||
Raises:
|
||||
Exception: If there are errors during storage operation
|
||||
"""
|
||||
async with self.lock:
|
||||
async with self.store_scope() as store:
|
||||
try:
|
||||
store[key] = data # Will be JSON serialized when written to file
|
||||
logging.info(f'Updated store key: {key}')
|
||||
return data
|
||||
except Exception as e:
|
||||
logging.error(f'Error in set: {str(e)}')
|
||||
raise
|
||||
async with self.store_scope() as store:
|
||||
store[key] = data
|
||||
logging.info(f'Updated store key: {key}')
|
||||
return data
|
||||
|
||||
async def get(self, key: str) -> Any:
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get data for a key.
|
||||
|
||||
|
@ -97,9 +116,6 @@ class StateManager:
|
|||
|
||||
Returns:
|
||||
Any: The stored data for the key, or None if key not found
|
||||
|
||||
Note:
|
||||
This operation is read-only and doesn't require locking
|
||||
"""
|
||||
async with self.store_scope() as store:
|
||||
if key in store:
|
||||
|
@ -115,17 +131,31 @@ class StateManager:
|
|||
|
||||
Args:
|
||||
key (str): Simple string key like "config" or "settings"
|
||||
|
||||
Note:
|
||||
No error is raised if the key doesn't exist
|
||||
"""
|
||||
async with self.lock:
|
||||
async with self.store_scope() as store:
|
||||
if key in store:
|
||||
del store[key]
|
||||
logging.info(f"Deleted key: {key}")
|
||||
else:
|
||||
logging.info(f"Key not found for deletion: {key}")
|
||||
async with self.store_scope() as store:
|
||||
if key in store:
|
||||
del store[key]
|
||||
logging.info(f"Deleted key: {key}")
|
||||
|
||||
async def update(self, key: str, data: Dict[str, Any]) -> Optional[Any]:
|
||||
"""Update existing data for a key by merging dictionaries."""
|
||||
async with self.store_scope() as store:
|
||||
if key in store and isinstance(store[key], dict):
|
||||
store[key].update(data)
|
||||
logging.info(f'Updated store key: {key}')
|
||||
return store[key]
|
||||
return None
|
||||
|
||||
async def append(self, key: str, item: Any) -> Optional[List[Any]]:
|
||||
"""Append an item to a list stored at key."""
|
||||
async with self.store_scope() as store:
|
||||
if key not in store:
|
||||
store[key] = []
|
||||
if isinstance(store[key], list):
|
||||
store[key].append(item)
|
||||
logging.info(f'Appended to key: {key}')
|
||||
return store[key]
|
||||
return None
|
||||
|
||||
async def export_store(self) -> dict:
|
||||
"""
|
||||
|
@ -133,9 +163,6 @@ class StateManager:
|
|||
|
||||
Returns:
|
||||
dict: Complete contents of the state store
|
||||
|
||||
Note:
|
||||
This operation is read-only and returns a copy of the store
|
||||
"""
|
||||
async with self.store_scope() as store:
|
||||
logging.info(f"Store content: {store}")
|
||||
|
@ -148,7 +175,29 @@ class StateManager:
|
|||
Removes all data from the store, resetting it to an empty state.
|
||||
This operation is atomic and thread-safe.
|
||||
"""
|
||||
async with self.lock:
|
||||
async with self.store_scope() as store:
|
||||
store.clear()
|
||||
logging.info("Cleared store")
|
||||
async with self.store_scope() as store:
|
||||
store.clear()
|
||||
logging.info("Cleared store")
|
||||
|
||||
@classmethod
|
||||
async def list_stores(cls) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all available state stores.
|
||||
|
||||
Returns:
|
||||
List of store information including IDs and timestamps
|
||||
"""
|
||||
db = DBConnection()
|
||||
async with db.transaction() as conn:
|
||||
async with conn.execute(
|
||||
"SELECT store_id, created_at, updated_at FROM state_stores ORDER BY updated_at DESC"
|
||||
) as cursor:
|
||||
stores = [
|
||||
{
|
||||
"store_id": row[0],
|
||||
"created_at": row[1],
|
||||
"updated_at": row[2]
|
||||
}
|
||||
for row in await cursor.fetchall()
|
||||
]
|
||||
return stores
|
||||
|
|
|
@ -11,21 +11,22 @@ This module provides comprehensive conversation management, including:
|
|||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator
|
||||
from agentpress.llm import make_llm_api_call
|
||||
from agentpress.tool import Tool, ToolResult
|
||||
from agentpress.tool_registry import ToolRegistry
|
||||
from agentpress.llm_response_processor import LLMResponseProcessor
|
||||
from agentpress.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
|
||||
from agentpress.processor.llm_response_processor import LLMResponseProcessor
|
||||
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
|
||||
from agentpress.db_connection import DBConnection
|
||||
|
||||
from agentpress.xml_tool_parser import XMLToolParser
|
||||
from agentpress.xml_tool_executor import XMLToolExecutor
|
||||
from agentpress.xml_results_adder import XMLResultsAdder
|
||||
from agentpress.standard_tool_parser import StandardToolParser
|
||||
from agentpress.standard_tool_executor import StandardToolExecutor
|
||||
from agentpress.standard_results_adder import StandardResultsAdder
|
||||
from agentpress.processor.xml.xml_tool_parser import XMLToolParser
|
||||
from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor
|
||||
from agentpress.processor.xml.xml_results_adder import XMLResultsAdder
|
||||
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
|
||||
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
|
||||
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
|
||||
|
||||
class ThreadManager:
|
||||
"""Manages conversation threads with LLM models and tool execution.
|
||||
|
@ -33,204 +34,163 @@ class ThreadManager:
|
|||
Provides comprehensive conversation management, handling message threading,
|
||||
tool registration, and LLM interactions with support for both standard and
|
||||
XML-based tool execution patterns.
|
||||
|
||||
Attributes:
|
||||
threads_dir (str): Directory for storing thread files
|
||||
tool_registry (ToolRegistry): Registry for managing available tools
|
||||
|
||||
Methods:
|
||||
add_tool: Register a tool with optional function filtering
|
||||
create_thread: Create a new conversation thread
|
||||
add_message: Add a message to a thread
|
||||
list_messages: Retrieve messages from a thread
|
||||
run_thread: Execute a conversation thread with LLM
|
||||
"""
|
||||
|
||||
def __init__(self, threads_dir: str = "threads"):
|
||||
"""Initialize ThreadManager.
|
||||
|
||||
Args:
|
||||
threads_dir: Directory to store thread files
|
||||
|
||||
Notes:
|
||||
Creates the threads directory if it doesn't exist
|
||||
"""
|
||||
self.threads_dir = threads_dir
|
||||
def __init__(self):
|
||||
"""Initialize ThreadManager."""
|
||||
self.db = DBConnection()
|
||||
self.tool_registry = ToolRegistry()
|
||||
os.makedirs(self.threads_dir, exist_ok=True)
|
||||
|
||||
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
|
||||
"""Add a tool to the ThreadManager.
|
||||
|
||||
Args:
|
||||
tool_class: The tool class to register
|
||||
function_names: Optional list of specific functions to register
|
||||
**kwargs: Additional arguments passed to tool initialization
|
||||
|
||||
Notes:
|
||||
- If function_names is None, all functions are registered
|
||||
- Tool instances are created with provided kwargs
|
||||
"""
|
||||
"""Add a tool to the ThreadManager."""
|
||||
self.tool_registry.register_tool(tool_class, function_names, **kwargs)
|
||||
|
||||
async def create_thread(self) -> str:
|
||||
"""Create a new conversation thread.
|
||||
|
||||
Returns:
|
||||
str: Unique thread ID for the created thread
|
||||
|
||||
Raises:
|
||||
IOError: If thread file creation fails
|
||||
|
||||
Notes:
|
||||
Creates a new thread file with an empty messages list
|
||||
"""
|
||||
"""Create a new conversation thread."""
|
||||
thread_id = str(uuid.uuid4())
|
||||
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||
with open(thread_path, 'w') as f:
|
||||
json.dump({"messages": []}, f)
|
||||
await self.db.execute(
|
||||
"INSERT INTO threads (thread_id, messages) VALUES (?, ?)",
|
||||
(thread_id, json.dumps([]))
|
||||
)
|
||||
return thread_id
|
||||
|
||||
async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
|
||||
"""Add a message to an existing thread.
|
||||
|
||||
Args:
|
||||
thread_id: ID of the target thread
|
||||
message_data: Message content and metadata
|
||||
images: Optional list of image data dictionaries
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If thread doesn't exist
|
||||
Exception: For other operation failures
|
||||
|
||||
Notes:
|
||||
- Handles cleanup of incomplete tool calls
|
||||
- Supports both text and image content
|
||||
- Converts ToolResult instances to strings
|
||||
"""
|
||||
"""Add a message to an existing thread."""
|
||||
logging.info(f"Adding message to thread {thread_id} with images: {images}")
|
||||
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||
|
||||
try:
|
||||
with open(thread_path, 'r') as f:
|
||||
thread_data = json.load(f)
|
||||
|
||||
messages = thread_data["messages"]
|
||||
|
||||
# Handle cleanup of incomplete tool calls
|
||||
if message_data['role'] == 'user':
|
||||
last_assistant_index = next((i for i in reversed(range(len(messages)))
|
||||
if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
|
||||
|
||||
if last_assistant_index is not None:
|
||||
tool_call_count = len(messages[last_assistant_index]['tool_calls'])
|
||||
tool_response_count = sum(1 for msg in messages[last_assistant_index+1:]
|
||||
if msg['role'] == 'tool')
|
||||
async with self.db.transaction() as conn:
|
||||
# Handle cleanup of incomplete tool calls
|
||||
if message_data['role'] == 'user':
|
||||
messages = await self.get_messages(thread_id)
|
||||
last_assistant_index = next((i for i in reversed(range(len(messages)))
|
||||
if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
|
||||
|
||||
if tool_call_count != tool_response_count:
|
||||
await self.cleanup_incomplete_tool_calls(thread_id)
|
||||
if last_assistant_index is not None:
|
||||
tool_call_count = len(messages[last_assistant_index]['tool_calls'])
|
||||
tool_response_count = sum(1 for msg in messages[last_assistant_index+1:]
|
||||
if msg['role'] == 'tool')
|
||||
|
||||
if tool_call_count != tool_response_count:
|
||||
await self.cleanup_incomplete_tool_calls(thread_id)
|
||||
|
||||
# Convert ToolResult instances to strings
|
||||
for key, value in message_data.items():
|
||||
if isinstance(value, ToolResult):
|
||||
message_data[key] = str(value)
|
||||
# Convert ToolResult instances to strings
|
||||
for key, value in message_data.items():
|
||||
if isinstance(value, ToolResult):
|
||||
message_data[key] = str(value)
|
||||
|
||||
# Handle image attachments
|
||||
if images:
|
||||
if isinstance(message_data['content'], str):
|
||||
message_data['content'] = [{"type": "text", "text": message_data['content']}]
|
||||
elif not isinstance(message_data['content'], list):
|
||||
message_data['content'] = []
|
||||
# Handle image attachments
|
||||
if images:
|
||||
if isinstance(message_data['content'], str):
|
||||
message_data['content'] = [{"type": "text", "text": message_data['content']}]
|
||||
elif not isinstance(message_data['content'], list):
|
||||
message_data['content'] = []
|
||||
|
||||
for image in images:
|
||||
image_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{image['content_type']};base64,{image['base64']}",
|
||||
"detail": "high"
|
||||
for image in images:
|
||||
image_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{image['content_type']};base64,{image['base64']}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
}
|
||||
message_data['content'].append(image_content)
|
||||
message_data['content'].append(image_content)
|
||||
|
||||
messages.append(message_data)
|
||||
thread_data["messages"] = messages
|
||||
# Get current messages
|
||||
row = await self.db.fetch_one(
|
||||
"SELECT messages FROM threads WHERE thread_id = ?",
|
||||
(thread_id,)
|
||||
)
|
||||
if not row:
|
||||
raise ValueError(f"Thread {thread_id} not found")
|
||||
|
||||
messages = json.loads(row[0])
|
||||
messages.append(message_data)
|
||||
|
||||
# Update thread
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE threads
|
||||
SET messages = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE thread_id = ?
|
||||
""",
|
||||
(json.dumps(messages), thread_id)
|
||||
)
|
||||
|
||||
logging.info(f"Message added to thread {thread_id}: {message_data}")
|
||||
|
||||
with open(thread_path, 'w') as f:
|
||||
json.dump(thread_data, f)
|
||||
|
||||
logging.info(f"Message added to thread {thread_id}: {message_data}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to add message to thread {thread_id}: {e}")
|
||||
raise e
|
||||
|
||||
async def list_messages(
|
||||
async def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
hide_tool_msgs: bool = False,
|
||||
only_latest_assistant: bool = False,
|
||||
thread_id: str,
|
||||
hide_tool_msgs: bool = False,
|
||||
only_latest_assistant: bool = False,
|
||||
regular_list: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Retrieve messages from a thread with optional filtering.
|
||||
|
||||
Args:
|
||||
thread_id: ID of the thread to retrieve messages from
|
||||
hide_tool_msgs: If True, excludes tool messages and tool calls
|
||||
only_latest_assistant: If True, returns only the most recent assistant message
|
||||
regular_list: If True, only includes standard message types
|
||||
|
||||
Returns:
|
||||
List of messages matching the filter criteria
|
||||
|
||||
Notes:
|
||||
- Returns empty list if thread doesn't exist
|
||||
- Filters can be combined for different views of the conversation
|
||||
"""
|
||||
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||
|
||||
try:
|
||||
with open(thread_path, 'r') as f:
|
||||
thread_data = json.load(f)
|
||||
messages = thread_data["messages"]
|
||||
|
||||
if only_latest_assistant:
|
||||
for msg in reversed(messages):
|
||||
if msg.get('role') == 'assistant':
|
||||
return [msg]
|
||||
return []
|
||||
|
||||
filtered_messages = messages
|
||||
|
||||
if hide_tool_msgs:
|
||||
filtered_messages = [
|
||||
{k: v for k, v in msg.items() if k != 'tool_calls'}
|
||||
for msg in filtered_messages
|
||||
if msg.get('role') != 'tool'
|
||||
]
|
||||
|
||||
if regular_list:
|
||||
filtered_messages = [
|
||||
msg for msg in filtered_messages
|
||||
if msg.get('role') in ['system', 'assistant', 'tool', 'user']
|
||||
]
|
||||
|
||||
return filtered_messages
|
||||
except FileNotFoundError:
|
||||
"""Retrieve messages from a thread with optional filtering."""
|
||||
row = await self.db.fetch_one(
|
||||
"SELECT messages FROM threads WHERE thread_id = ?",
|
||||
(thread_id,)
|
||||
)
|
||||
if not row:
|
||||
return []
|
||||
|
||||
messages = json.loads(row[0])
|
||||
|
||||
if only_latest_assistant:
|
||||
for msg in reversed(messages):
|
||||
if msg.get('role') == 'assistant':
|
||||
return [msg]
|
||||
return []
|
||||
|
||||
if hide_tool_msgs:
|
||||
messages = [
|
||||
{k: v for k, v in msg.items() if k != 'tool_calls'}
|
||||
for msg in messages
|
||||
if msg.get('role') != 'tool'
|
||||
]
|
||||
|
||||
if regular_list:
|
||||
messages = [
|
||||
msg for msg in messages
|
||||
if msg.get('role') in ['system', 'assistant', 'tool', 'user']
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
async def _update_message(self, thread_id: str, message: Dict[str, Any]):
|
||||
"""Update an existing message in the thread."""
|
||||
async with self.db.transaction() as conn:
|
||||
row = await self.db.fetch_one(
|
||||
"SELECT messages FROM threads WHERE thread_id = ?",
|
||||
(thread_id,)
|
||||
)
|
||||
if not row:
|
||||
return
|
||||
|
||||
messages = json.loads(row[0])
|
||||
|
||||
# Find and update the last assistant message
|
||||
for i in reversed(range(len(messages))):
|
||||
if messages[i].get('role') == 'assistant':
|
||||
messages[i] = message
|
||||
break
|
||||
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE threads
|
||||
SET messages = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE thread_id = ?
|
||||
""",
|
||||
(json.dumps(messages), thread_id)
|
||||
)
|
||||
|
||||
async def cleanup_incomplete_tool_calls(self, thread_id: str):
|
||||
"""Clean up incomplete tool calls in a thread.
|
||||
|
||||
Args:
|
||||
thread_id: ID of the thread to clean up
|
||||
|
||||
Returns:
|
||||
bool: True if cleanup was performed, False otherwise
|
||||
|
||||
Notes:
|
||||
- Adds failure results for incomplete tool calls
|
||||
- Maintains thread consistency after interruptions
|
||||
"""
|
||||
messages = await self.list_messages(thread_id)
|
||||
"""Clean up incomplete tool calls in a thread."""
|
||||
messages = await self.get_messages(thread_id)
|
||||
last_assistant_message = next((m for m in reversed(messages)
|
||||
if m['role'] == 'assistant' and 'tool_calls' in m), None)
|
||||
|
||||
|
@ -253,10 +213,15 @@ class ThreadManager:
|
|||
assistant_index = messages.index(last_assistant_message)
|
||||
messages[assistant_index+1:assistant_index+1] = failed_tool_results
|
||||
|
||||
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||
with open(thread_path, 'w') as f:
|
||||
json.dump({"messages": messages}, f)
|
||||
|
||||
async with self.db.transaction() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE threads
|
||||
SET messages = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE thread_id = ?
|
||||
""",
|
||||
(json.dumps(messages), thread_id)
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -326,7 +291,7 @@ class ThreadManager:
|
|||
results_adder = XMLResultsAdder(self) if xml_tool_calling else StandardResultsAdder(self)
|
||||
|
||||
try:
|
||||
messages = await self.list_messages(thread_id)
|
||||
messages = await self.get_messages(thread_id)
|
||||
prepared_messages = [system_message] + messages
|
||||
if temporary_message:
|
||||
prepared_messages.append(temporary_message)
|
||||
|
@ -345,9 +310,8 @@ class ThreadManager:
|
|||
available_functions=available_functions,
|
||||
add_message_callback=self.add_message,
|
||||
update_message_callback=self._update_message,
|
||||
list_messages_callback=self.list_messages,
|
||||
get_messages_callback=self.get_messages,
|
||||
parallel_tool_execution=parallel_tool_execution,
|
||||
threads_dir=self.threads_dir,
|
||||
tool_parser=tool_parser,
|
||||
tool_executor=tool_executor,
|
||||
results_adder=results_adder
|
||||
|
@ -405,25 +369,6 @@ class ThreadManager:
|
|||
stream=stream
|
||||
)
|
||||
|
||||
async def _update_message(self, thread_id: str, message: Dict[str, Any]):
|
||||
"""Update an existing message in the thread."""
|
||||
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||
try:
|
||||
with open(thread_path, 'r') as f:
|
||||
thread_data = json.load(f)
|
||||
|
||||
# Find and update the last assistant message
|
||||
for i in reversed(range(len(thread_data["messages"]))):
|
||||
if thread_data["messages"][i]["role"] == "assistant":
|
||||
thread_data["messages"][i] = message
|
||||
break
|
||||
|
||||
with open(thread_path, 'w') as f:
|
||||
json.dump(thread_data, f)
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating message in thread {thread_id}: {e}")
|
||||
raise e
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
from agentpress.examples.example_agent.tools.files_tool import FilesTool
|
||||
|
@ -503,7 +448,7 @@ if __name__ == "__main__":
|
|||
print("\n✨ Response completed\n")
|
||||
|
||||
# Display final thread state
|
||||
messages = await thread_manager.list_messages(thread_id)
|
||||
messages = await thread_manager.get_messages(thread_id)
|
||||
print("\n📝 Final Thread State:")
|
||||
for msg in messages:
|
||||
role = msg.get('role', 'unknown')
|
||||
|
|
|
@ -1,21 +1,8 @@
|
|||
import streamlit as st
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
def load_thread_files(threads_dir: str):
|
||||
"""Load all thread files from the threads directory."""
|
||||
thread_files = []
|
||||
if os.path.exists(threads_dir):
|
||||
for file in os.listdir(threads_dir):
|
||||
if file.endswith('.json'):
|
||||
thread_files.append(file)
|
||||
return thread_files
|
||||
|
||||
def load_thread_content(thread_file: str, threads_dir: str):
|
||||
"""Load the content of a specific thread file."""
|
||||
with open(os.path.join(threads_dir, thread_file), 'r') as f:
|
||||
return json.load(f)
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from agentpress.db_connection import DBConnection
|
||||
import asyncio
|
||||
|
||||
def format_message_content(content):
|
||||
"""Format message content handling both string and list formats."""
|
||||
|
@ -31,89 +18,123 @@ def format_message_content(content):
|
|||
return "\n".join(formatted_content)
|
||||
return str(content)
|
||||
|
||||
async def load_threads():
|
||||
"""Load all thread IDs from the database."""
|
||||
db = DBConnection()
|
||||
rows = await db.fetch_all("SELECT thread_id, created_at FROM threads ORDER BY created_at DESC")
|
||||
return rows
|
||||
|
||||
async def load_thread_content(thread_id: str):
|
||||
"""Load the content of a specific thread from the database."""
|
||||
thread_manager = ThreadManager()
|
||||
return await thread_manager.get_messages(thread_id)
|
||||
|
||||
def render_message(role, content, avatar):
|
||||
"""Render a message with a consistent chat-like style."""
|
||||
# Create columns for avatar and message
|
||||
col1, col2 = st.columns([1, 11])
|
||||
|
||||
# Style based on role
|
||||
if role == "assistant":
|
||||
bgcolor = "rgba(25, 25, 25, 0.05)"
|
||||
elif role == "user":
|
||||
bgcolor = "rgba(25, 120, 180, 0.05)"
|
||||
elif role == "system":
|
||||
bgcolor = "rgba(180, 25, 25, 0.05)"
|
||||
else:
|
||||
bgcolor = "rgba(100, 100, 100, 0.05)"
|
||||
|
||||
# Display avatar in first column
|
||||
with col1:
|
||||
st.markdown(f"<div style='text-align: center; font-size: 24px;'>{avatar}</div>", unsafe_allow_html=True)
|
||||
|
||||
# Display message in second column
|
||||
with col2:
|
||||
st.markdown(
|
||||
f"""
|
||||
<div style='background-color: {bgcolor}; padding: 10px; border-radius: 5px;'>
|
||||
<strong>{role.upper()}</strong><br>
|
||||
{content}
|
||||
</div>
|
||||
""",
|
||||
unsafe_allow_html=True
|
||||
)
|
||||
|
||||
def main():
|
||||
st.title("Thread Viewer")
|
||||
|
||||
# Directory selection in sidebar
|
||||
st.sidebar.title("Configuration")
|
||||
# Initialize thread data in session state
|
||||
if 'threads' not in st.session_state:
|
||||
st.session_state.threads = asyncio.run(load_threads())
|
||||
|
||||
# Initialize session state with default directory
|
||||
if 'threads_dir' not in st.session_state:
|
||||
default_dir = "./threads"
|
||||
if os.path.exists(default_dir):
|
||||
st.session_state.threads_dir = default_dir
|
||||
|
||||
# Use Streamlit's file uploader for directory selection
|
||||
uploaded_dir = st.sidebar.text_input(
|
||||
"Enter threads directory path",
|
||||
value="./threads" if not st.session_state.threads_dir else st.session_state.threads_dir,
|
||||
placeholder="/path/to/threads",
|
||||
help="Enter the full path to your threads directory"
|
||||
# Thread selection in sidebar
|
||||
st.sidebar.title("Select Thread")
|
||||
|
||||
if not st.session_state.threads:
|
||||
st.warning("No threads found in database")
|
||||
return
|
||||
|
||||
# Format thread options with creation date
|
||||
thread_options = {
|
||||
f"{row[0]} ({datetime.fromisoformat(row[1]).strftime('%Y-%m-%d %H:%M')})"
|
||||
: row[0] for row in st.session_state.threads
|
||||
}
|
||||
|
||||
selected_thread_display = st.sidebar.selectbox(
|
||||
"Choose a thread",
|
||||
options=list(thread_options.keys()),
|
||||
)
|
||||
|
||||
# Automatically load directory if it exists
|
||||
if os.path.exists(uploaded_dir):
|
||||
st.session_state.threads_dir = uploaded_dir
|
||||
else:
|
||||
st.sidebar.error("Directory not found!")
|
||||
|
||||
if st.session_state.threads_dir:
|
||||
st.sidebar.success(f"Selected directory: {st.session_state.threads_dir}")
|
||||
threads_dir = st.session_state.threads_dir
|
||||
if selected_thread_display:
|
||||
# Get the actual thread ID from the display string
|
||||
selected_thread_id = thread_options[selected_thread_display]
|
||||
|
||||
# Thread selection
|
||||
st.sidebar.title("Select Thread")
|
||||
thread_files = load_thread_files(threads_dir)
|
||||
# Display thread ID in sidebar
|
||||
st.sidebar.text(f"Thread ID: {selected_thread_id}")
|
||||
|
||||
if not thread_files:
|
||||
st.warning(f"No thread files found in '{threads_dir}'")
|
||||
return
|
||||
# Add refresh button
|
||||
if st.sidebar.button("🔄 Refresh Thread"):
|
||||
st.session_state.threads = asyncio.run(load_threads())
|
||||
st.experimental_rerun()
|
||||
|
||||
selected_thread = st.sidebar.selectbox(
|
||||
"Choose a thread file",
|
||||
thread_files,
|
||||
format_func=lambda x: f"Thread: {x.replace('.json', '')}"
|
||||
)
|
||||
# Load and display messages
|
||||
messages = asyncio.run(load_thread_content(selected_thread_id))
|
||||
|
||||
if selected_thread:
|
||||
thread_data = load_thread_content(selected_thread, threads_dir)
|
||||
messages = thread_data.get("messages", [])
|
||||
# Display messages in chat-like interface
|
||||
for message in messages:
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
|
||||
# Display thread ID in sidebar
|
||||
st.sidebar.text(f"Thread ID: {selected_thread.replace('.json', '')}")
|
||||
# Determine avatar based on role
|
||||
if role == "assistant":
|
||||
avatar = "🤖"
|
||||
elif role == "user":
|
||||
avatar = "👤"
|
||||
elif role == "system":
|
||||
avatar = "⚙️"
|
||||
elif role == "tool":
|
||||
avatar = "🔧"
|
||||
else:
|
||||
avatar = "❓"
|
||||
|
||||
# Display messages in chat-like interface
|
||||
for message in messages:
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
|
||||
# Determine avatar based on role
|
||||
if role == "assistant":
|
||||
avatar = "🤖"
|
||||
elif role == "user":
|
||||
avatar = "👤"
|
||||
elif role == "system":
|
||||
avatar = "⚙️"
|
||||
elif role == "tool":
|
||||
avatar = "🔧"
|
||||
else:
|
||||
avatar = "❓"
|
||||
|
||||
# Format the message container
|
||||
with st.chat_message(role, avatar=avatar):
|
||||
formatted_content = format_message_content(content)
|
||||
st.markdown(formatted_content)
|
||||
|
||||
if "tool_calls" in message:
|
||||
st.markdown("**Tool Calls:**")
|
||||
for tool_call in message["tool_calls"]:
|
||||
st.code(
|
||||
f"Function: {tool_call['function']['name']}\n"
|
||||
f"Arguments: {tool_call['function']['arguments']}",
|
||||
language="json"
|
||||
)
|
||||
else:
|
||||
st.sidebar.warning("Please enter and load a threads directory")
|
||||
# Format the content
|
||||
formatted_content = format_message_content(content)
|
||||
|
||||
# Render the message
|
||||
render_message(role, formatted_content, avatar)
|
||||
|
||||
# Display tool calls if present
|
||||
if "tool_calls" in message:
|
||||
with st.expander("🛠️ Tool Calls"):
|
||||
for tool_call in message["tool_calls"]:
|
||||
st.code(
|
||||
f"Function: {tool_call['function']['name']}\n"
|
||||
f"Arguments: {tool_call['function']['arguments']}",
|
||||
language="json"
|
||||
)
|
||||
|
||||
# Add some spacing between messages
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -160,27 +160,44 @@ files = [
|
|||
frozenlist = ">=1.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "altair"
|
||||
version = "5.4.1"
|
||||
description = "Vega-Altair: A declarative statistical visualization library for Python."
|
||||
name = "aiosqlite"
|
||||
version = "0.20.0"
|
||||
description = "asyncio bridge to the standard sqlite3 module"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "altair-5.4.1-py3-none-any.whl", hash = "sha256:0fb130b8297a569d08991fb6fe763582e7569f8a04643bbd9212436e3be04aef"},
|
||||
{file = "altair-5.4.1.tar.gz", hash = "sha256:0ce8c2e66546cb327e5f2d7572ec0e7c6feece816203215613962f0ec1d76a82"},
|
||||
{file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"},
|
||||
{file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
jinja2 = "*"
|
||||
jsonschema = ">=3.0"
|
||||
narwhals = ">=1.5.2"
|
||||
packaging = "*"
|
||||
typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""}
|
||||
typing_extensions = ">=4.0"
|
||||
|
||||
[package.extras]
|
||||
all = ["altair-tiles (>=0.3.0)", "anywidget (>=0.9.0)", "numpy", "pandas (>=0.25.3)", "pyarrow (>=11)", "vega-datasets (>=0.9.0)", "vegafusion[embed] (>=1.6.6)", "vl-convert-python (>=1.6.0)"]
|
||||
dev = ["geopandas", "hatch", "ibis-framework[polars]", "ipython[kernel]", "mistune", "mypy", "pandas (>=0.25.3)", "pandas-stubs", "polars (>=0.20.3)", "pytest", "pytest-cov", "pytest-xdist[psutil] (>=3.5,<4.0)", "ruff (>=0.6.0)", "types-jsonschema", "types-setuptools"]
|
||||
doc = ["docutils", "jinja2", "myst-parser", "numpydoc", "pillow (>=9,<10)", "pydata-sphinx-theme (>=0.14.1)", "scipy", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinxext-altair"]
|
||||
dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"]
|
||||
docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "altair"
|
||||
version = "4.2.2"
|
||||
description = "Altair: A declarative statistical visualization library for Python."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "altair-4.2.2-py3-none-any.whl", hash = "sha256:8b45ebeaf8557f2d760c5c77b79f02ae12aee7c46c27c06014febab6f849bc87"},
|
||||
{file = "altair-4.2.2.tar.gz", hash = "sha256:39399a267c49b30d102c10411e67ab26374156a84b1aeb9fcd15140429ba49c5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
entrypoints = "*"
|
||||
jinja2 = "*"
|
||||
jsonschema = ">=3.0"
|
||||
numpy = "*"
|
||||
pandas = ">=0.18"
|
||||
toolz = "*"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "docutils", "flake8", "ipython", "m2r", "mistune (<2.0.0)", "pytest", "recommonmark", "sphinx", "vega-datasets"]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
|
@ -226,6 +243,19 @@ files = [
|
|||
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asyncio"
|
||||
version = "3.4.3"
|
||||
description = "reference implementation of PEP 3156"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "asyncio-3.4.3-cp33-none-win32.whl", hash = "sha256:b62c9157d36187eca799c378e572c969f0da87cd5fc42ca372d92cdb06e7e1de"},
|
||||
{file = "asyncio-3.4.3-cp33-none-win_amd64.whl", hash = "sha256:c46a87b48213d7464f22d9a497b9eef8c1928b68320a2fa94240f969f6fec08c"},
|
||||
{file = "asyncio-3.4.3-py3-none-any.whl", hash = "sha256:c4d18b22701821de07bd6aea8b53d21449ec0ec5680645e5317062ea21817d2d"},
|
||||
{file = "asyncio-3.4.3.tar.gz", hash = "sha256:83360ff8bc97980e4ff25c964c7bd3923d333d177aa4f7fb736b019f26c7cb41"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "attrs"
|
||||
version = "24.2.0"
|
||||
|
@ -428,6 +458,17 @@ files = [
|
|||
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "entrypoints"
|
||||
version = "0.4"
|
||||
description = "Discover and load entry points from installed packages."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "entrypoints-0.4-py3-none-any.whl", hash = "sha256:f174b5ff827504fd3cd97cc3f8649f3693f51538c7e4bdf3ef002c8429d42f9f"},
|
||||
{file = "entrypoints-0.4.tar.gz", hash = "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.2.2"
|
||||
|
@ -1140,25 +1181,6 @@ files = [
|
|||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "narwhals"
|
||||
version = "1.12.1"
|
||||
description = "Extremely lightweight compatibility layer between dataframe libraries"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "narwhals-1.12.1-py3-none-any.whl", hash = "sha256:e251cb5fe4cabdcabb847d359f5de2b81df773df47e46f858fd5570c936919c4"},
|
||||
{file = "narwhals-1.12.1.tar.gz", hash = "sha256:65ff0d1e8b509df8b52b395e8d5fe96751a68657bdabf0f3057a970ec2cd1809"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
cudf = ["cudf (>=23.08.00)"]
|
||||
dask = ["dask[dataframe] (>=2024.7)"]
|
||||
modin = ["modin"]
|
||||
pandas = ["pandas (>=0.25.3)"]
|
||||
polars = ["polars (>=0.20.3)"]
|
||||
pyarrow = ["pyarrow (>=11.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "2.0.2"
|
||||
|
@ -1301,9 +1323,9 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
@ -1690,8 +1712,8 @@ files = [
|
|||
annotated-types = ">=0.6.0"
|
||||
pydantic-core = "2.23.4"
|
||||
typing-extensions = [
|
||||
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=4.6.1", markers = "python_version < \"3.13\""},
|
||||
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
|
@ -2612,6 +2634,17 @@ files = [
|
|||
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toolz"
|
||||
version = "1.0.0"
|
||||
description = "List processing tools and functional utilities"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236"},
|
||||
{file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.4.1"
|
||||
|
@ -2893,4 +2926,4 @@ type = ["pytest-mypy"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "d4c229cc0ecd64741dcf73186dce1c90100230e57acca3aa589e66dbb3052e9d"
|
||||
content-hash = "32b3aefcb3a32a251cd1de7abaa721729c4e2ce2640625a80fcf51a0e3d5da2f"
|
||||
|
|
|
@ -30,6 +30,9 @@ packaging = "^23.2"
|
|||
setuptools = "^75.3.0"
|
||||
pytest = "^8.3.3"
|
||||
pytest-asyncio = "^0.24.0"
|
||||
aiosqlite = "^0.20.0"
|
||||
asyncio = "^3.4.3"
|
||||
altair = "4.2.2"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
agentpress = "agentpress.cli:main"
|
||||
|
|
Loading…
Reference in New Issue