api server, thread ws, api factory

This commit is contained in:
marko-kraemer 2025-02-03 15:04:59 +00:00
parent a455980807
commit e72b15c728
32 changed files with 3071 additions and 1488 deletions

1
agentpress/__init__.py Normal file
View File

@ -0,0 +1 @@
# Empty file to mark as package

View File

@ -8,23 +8,22 @@ This agent can:
- Use either XML or Standard tool calling patterns
"""
import os
import asyncio
import json
from agentpress.thread_manager import ThreadManager
from tools.files_tool import FilesTool
from example.tools.files_tool import FilesTool
from agentpress.state_manager import StateManager
from tools.terminal_tool import TerminalTool
from agentpress.api_factory import register_api_endpoint
from example.tools.terminal_tool import TerminalTool
import logging
from typing import AsyncGenerator, Optional, Dict, Any
import sys
from agentpress.api.api_factory import register_thread_task_api
BASE_SYSTEM_MESSAGE = """
You are a world-class web developer who can create, edit, and delete files, and execute terminal commands.
You write clean, well-structured code. Keep iterating on existing files, continue working on this existing
codebase - do not omit previous progress; instead, keep iterating.
Available tools:
- create_file: Create new files with specified content
- delete_file: Remove existing files
@ -69,7 +68,6 @@ Example workspace state for a file:
}
}
Think deeply and step by step.
"""
XML_FORMAT = """
@ -88,87 +86,77 @@ file contents here
<delete-file file_path="path/to/file">
</delete-file>
<stop_session></stop_session>
"""
def get_anthropic_api_key():
"""Get Anthropic API key from environment or prompt user."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
api_key = input("\n🔑 Please enter your Anthropic API key: ").strip()
if not api_key:
print("❌ No API key provided. Please set ANTHROPIC_API_KEY environment variable or enter a key.")
sys.exit(1)
os.environ["ANTHROPIC_API_KEY"] = api_key
return api_key
@register_api_endpoint("/main_agent")
@register_thread_task_api("/agent")
async def run_agent(
thread_id: str,
use_xml: bool = True,
max_iterations: int = 5,
project_description: Optional[str] = None
user_input: Optional[str] = None,
) -> Dict[str, Any]:
"""Run the development agent with specified configuration."""
# Initialize managers
thread_manager = ThreadManager()
await thread_manager.initialize()
"""Run the development agent with specified configuration.
state_manager = StateManager(thread_id)
await state_manager.initialize()
Args:
thread_id (str): The ID of the thread.
max_iterations (int, optional): The maximum number of iterations. Defaults to 5.
user_input (Optional[str], optional): The user input. Defaults to None.
"""
thread_manager = ThreadManager()
state_manager = StateManager(thread_id)
# Register tools
thread_manager.add_tool(FilesTool, thread_id=thread_id)
thread_manager.add_tool(TerminalTool, thread_id=thread_id)
# Add initial project description if provided
if project_description:
if user_input:
await thread_manager.add_message(
thread_id,
{
"role": "user",
"content": project_description
"content": user_input
}
)
# Set up system message with appropriate format
thread_manager.add_tool(FilesTool, thread_id=thread_id)
thread_manager.add_tool(TerminalTool, thread_id=thread_id)
system_message = {
"role": "system",
"content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "")
"content": BASE_SYSTEM_MESSAGE + XML_FORMAT
}
# Create initial event to track agent loop
await thread_manager.create_event(
thread_id=thread_id,
event_type="agent_loop_started",
content={
"max_iterations": max_iterations,
"use_xml": use_xml,
"project_description": project_description
},
include_in_llm_message_history=False
)
results = []
iteration = 0
while iteration < max_iterations:
iteration += 1
files_tool = FilesTool(thread_id)
await files_tool._init_workspace_state()
files_tool = FilesTool(thread_id=thread_id)
state = await state_manager.get_latest_state()
state_message = {
"role": "user",
"content": f"""
Current development environment workspace state:
<current_workspace_state>
{json.dumps(state, indent=2)}
</current_workspace_state>
"""
state = await state_manager.export_store()
temporary_message_content = f"""
You are tasked to complete the LATEST USER REQUEST!
<latest_user_request>
{user_input}
</latest_user_request>
Current development environment workspace state:
<current_workspace_state>
{json.dumps(state, indent=2) if state else "{}"}
</current_workspace_state>
CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE.
"""
await thread_manager.add_message(
thread_id=thread_id,
message_data=temporary_message_content,
message_type="temporary_message",
include_in_llm_message_history=False
)
temporary_message = {
"role": "user",
"content": temporary_message_content
}
model_name = "anthropic/claude-3-5-sonnet-latest"
model_name = "anthropic/claude-3-5-sonnet-latest"
response = await thread_manager.run_thread(
thread_id=thread_id,
@ -177,51 +165,57 @@ Current development environment workspace state:
temperature=0.1,
max_tokens=8096,
tool_choice="auto",
temporary_message=state_message,
native_tool_calling=not use_xml,
xml_tool_calling=use_xml,
temporary_message=temporary_message,
native_tool_calling=False,
xml_tool_calling=True,
stream=True,
execute_tools_on_stream=False,
parallel_tool_execution=True
execute_tools_on_stream=True,
parallel_tool_execution=True,
)
# Handle both streaming and regular responses
if hasattr(response, '__aiter__'):
chunks = []
if isinstance(response, AsyncGenerator):
print("\n🤖 Assistant is responding:")
try:
async for chunk in response:
chunks.append(chunk)
if hasattr(chunk.choices[0], 'delta'):
delta = chunk.choices[0].delta
if hasattr(delta, 'content') and delta.content is not None:
content = delta.content
print(content, end='', flush=True)
# Check for open_files_in_editor tag and continue if found
if '</open_files_in_editor>' in content:
print("\n📂 Opening files in editor, continuing to next iteration...")
continue
if hasattr(delta, 'tool_calls') and delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.function:
if tool_call.function.name:
print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
if tool_call.function.arguments:
print(f" {tool_call.function.arguments}", end='', flush=True)
print("\n✨ Response completed\n")
except Exception as e:
print(f"\n❌ Error processing stream: {e}", file=sys.stderr)
logging.error(f"Error processing stream: {e}")
raise
response = chunks
else:
print("\nNon-streaming response received:", response)
results.append({
"iteration": iteration,
"response": response
})
# # Get latest assistant message and check for stop_session
# latest_msg = await thread_manager.get_llm_history_messages(
# thread_id=thread_id,
# only_latest_assistant=True
# )
# if latest_msg and '</stop_session>' in latest_msg:
# break
# Create iteration completion event
await thread_manager.create_event(
thread_id=thread_id,
event_type="iteration_complete",
content={
"iteration_number": iteration,
"max_iterations": max_iterations,
# "state": state
},
include_in_llm_message_history=False
)
return {
"thread_id": thread_id,
"iterations": results,
}
if __name__ == "__main__":
print("\n🚀 Welcome to AgentPress Web Developer Example!")
get_anthropic_api_key()
print("\n🚀 Welcome to AgentPress!")
project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ")
if not project_description.strip():
@ -241,10 +235,27 @@ if __name__ == "__main__":
print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}")
print("Use Ctrl+C to stop the agent at any time.")
async def async_main():
async def test_agent():
thread_manager = ThreadManager()
thread_id = await thread_manager.create_thread()
logging.info(f"Created new thread: {thread_id}")
await run_agent(thread_id, use_xml, project_description=project_description)
try:
result = await run_agent(
thread_id=thread_id,
max_iterations=5,
user_input=project_description,
)
print("\n✅ Test completed successfully!")
except Exception as e:
print(f"\n❌ Test failed: {str(e)}")
raise
asyncio.run(async_main())
try:
asyncio.run(test_agent())
except KeyboardInterrupt:
print("\n⚠️ Test interrupted by user")
except Exception as e:
print(f"\n❌ Test failed with error: {str(e)}")
raise

View File

@ -60,13 +60,8 @@ class FilesTool(Tool):
os.makedirs(self.workspace, exist_ok=True)
if thread_id:
self.state_manager = StateManager(thread_id)
asyncio.create_task(self._init_state())
self.SNIPPET_LINES = 4 # Number of context lines to show around edits
async def _init_state(self):
"""Initialize state manager and workspace state."""
await self.state_manager.initialize()
await self._init_workspace_state()
asyncio.create_task(self._init_workspace_state())
self.SNIPPET_LINES = 4
def _should_exclude_file(self, rel_path: str) -> bool:
"""Check if a file should be excluded based on path, name, or extension"""
@ -264,6 +259,9 @@ class FilesTool(Tool):
new_content = content.replace(old_str, new_str)
full_path.write_text(new_content)
# Update state after file modification
await self._update_workspace_state()
# Show snippet around the edit
replacement_line = content.split(old_str)[0].count('\n')
start_line = max(0, replacement_line - self.SNIPPET_LINES)

View File

@ -2,7 +2,6 @@ import os
import asyncio
import subprocess
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.state_manager import StateManager
from typing import Optional
class TerminalTool(Tool):
@ -12,24 +11,6 @@ class TerminalTool(Tool):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
if thread_id:
self.state_manager = StateManager(thread_id)
asyncio.create_task(self._init_state())
async def _init_state(self):
"""Initialize state manager."""
await self.state_manager.initialize()
async def _update_command_history(self, command: str, output: str, success: bool):
"""Update command history in state"""
history = await self.state_manager.get("terminal_history") or []
history.append({
"command": command,
"output": output,
"success": success,
"cwd": os.path.relpath(os.getcwd(), self.workspace)
})
await self.state_manager.set("terminal_history", history)
@openapi_schema({
"type": "function",
@ -76,12 +57,6 @@ class TerminalTool(Tool):
error = stderr.decode() if stderr else ""
success = process.returncode == 0
await self._update_command_history(
command=command,
output=output + error,
success=success
)
if success:
return self.success_response({
"output": output,
@ -93,11 +68,6 @@ class TerminalTool(Tool):
return self.fail_response(f"Command failed with exit code {process.returncode}: {error}")
except Exception as e:
await self._update_command_history(
command=command,
output=str(e),
success=False
)
return self.fail_response(f"Error executing command: {str(e)}")
finally:
os.chdir(original_dir)

View File

@ -1,253 +0,0 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
from agentpress.thread_manager import ThreadManager
import asyncio
import json
import uvicorn
import logging
import importlib
from agentpress.api_factory import app as api_app, discover_tasks
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global managers
thread_manager: Optional[ThreadManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for FastAPI application."""
# Startup
global thread_manager
thread_manager = ThreadManager()
await thread_manager.initialize()
yield
# Shutdown
# Add any cleanup code here if needed
# Create FastAPI app
app = FastAPI(title="AgentPress API", lifespan=lifespan)
# Enable CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Import and mount the API Factory app
try:
# Run task discovery
discover_tasks()
logger.info("Task discovery completed")
# Mount the API Factory app at /tasks instead of root
app.mount("/tasks", api_app)
logger.info("Mounted API Factory app at /tasks")
except Exception as e:
logger.error(f"Error setting up API Factory: {e}")
raise
# WebSocket connection manager
class WebSocketManager:
"""Manages WebSocket connections for real-time thread updates."""
def __init__(self):
self.active_connections: Dict[str, List[WebSocket]] = {}
async def connect(self, websocket: WebSocket, thread_id: str):
"""Connect a WebSocket to a thread."""
await websocket.accept()
if thread_id not in self.active_connections:
self.active_connections[thread_id] = []
self.active_connections[thread_id].append(websocket)
def disconnect(self, websocket: WebSocket, thread_id: str):
"""Disconnect a WebSocket from a thread."""
if thread_id in self.active_connections:
self.active_connections[thread_id].remove(websocket)
if not self.active_connections[thread_id]:
del self.active_connections[thread_id]
async def broadcast_to_thread(self, thread_id: str, message: dict):
"""Broadcast a message to all connections in a thread."""
if thread_id in self.active_connections:
for connection in self.active_connections[thread_id]:
try:
await connection.send_json(message)
except WebSocketDisconnect:
self.disconnect(connection, thread_id)
# Initialize WebSocket manager
ws_manager = WebSocketManager()
# Pydantic models for request/response validation
class EventCreate(BaseModel):
event_type: str
content: Dict[str, Any]
include_in_llm_message_history: bool = False
llm_message: Optional[Dict[str, Any]] = None
class EventUpdate(BaseModel):
content: Optional[Dict[str, Any]] = None
include_in_llm_message_history: Optional[bool] = None
llm_message: Optional[Dict[str, Any]] = None
class ThreadEvents(BaseModel):
only_llm_messages: bool = False
event_types: Optional[List[str]] = None
order_by: str = "created_at"
order: str = "ASC"
# REST API Endpoints
@app.post("/threads", response_model=dict, status_code=201)
async def create_thread():
"""Create a new thread."""
thread_id = await thread_manager.create_thread()
return {"thread_id": thread_id}
@app.delete("/threads/{thread_id}", status_code=204)
async def delete_thread(thread_id: str):
"""Delete a thread and all its events."""
success = await thread_manager.delete_thread(thread_id)
if not success:
raise HTTPException(status_code=404, detail="Thread not found")
return None
@app.post("/threads/{thread_id}/events", response_model=dict, status_code=201)
async def create_event(thread_id: str, event: EventCreate, background_tasks: BackgroundTasks):
"""Create a new event in a thread."""
# First verify thread exists
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
try:
event_id = await thread_manager.create_event(
thread_id=thread_id,
event_type=event.event_type,
content=event.content,
include_in_llm_message_history=event.include_in_llm_message_history,
llm_message=event.llm_message
)
# Broadcast to WebSocket connections
background_tasks.add_task(
ws_manager.broadcast_to_thread,
thread_id,
{"type": "event_created", "event_id": event_id, "event": event.dict()}
)
return {"event_id": event_id}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.delete("/threads/{thread_id}/events/{event_id}", status_code=204)
async def delete_event(thread_id: str, event_id: str, background_tasks: BackgroundTasks):
"""Delete a specific event."""
# First verify thread exists
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
# Then verify event exists and belongs to thread
if not await thread_manager.event_belongs_to_thread(event_id, thread_id):
raise HTTPException(status_code=404, detail="Event not found in this thread")
success = await thread_manager.delete_event(event_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete event")
# Broadcast to WebSocket connections
background_tasks.add_task(
ws_manager.broadcast_to_thread,
thread_id,
{"type": "event_deleted", "event_id": event_id}
)
return None
@app.patch("/threads/{thread_id}/events/{event_id}", status_code=200)
async def update_event(thread_id: str, event_id: str, event: EventUpdate, background_tasks: BackgroundTasks):
"""Update an existing event."""
# First verify thread exists
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
# Then verify event exists and belongs to thread
if not await thread_manager.event_belongs_to_thread(event_id, thread_id):
raise HTTPException(status_code=404, detail="Event not found in this thread")
success = await thread_manager.update_event(
event_id=event_id,
thread_id=thread_id,
content=event.content,
include_in_llm_message_history=event.include_in_llm_message_history,
llm_message=event.llm_message
)
if not success:
raise HTTPException(status_code=500, detail="Failed to update event")
# Broadcast to WebSocket connections
background_tasks.add_task(
ws_manager.broadcast_to_thread,
thread_id,
{"type": "event_updated", "event_id": event_id, "updates": event.dict(exclude_unset=True)}
)
return {"status": "success"}
@app.get("/threads/{thread_id}/events")
async def get_thread_events(
thread_id: str,
only_llm_messages: bool = False,
event_types: Optional[List[str]] = None,
order_by: str = "created_at",
order: str = "ASC"
):
"""Get events from a thread with filtering options."""
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
events = await thread_manager.get_thread_events(
thread_id=thread_id,
only_llm_messages=only_llm_messages,
event_types=event_types,
order_by=order_by,
order=order
)
return {"events": events}
@app.get("/threads/{thread_id}/messages")
async def get_thread_messages(thread_id: str):
"""Get LLM-formatted messages from thread events."""
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
messages = await thread_manager.get_thread_llm_messages(thread_id)
return {"messages": messages}
# WebSocket Endpoint
@app.websocket("/ws/threads/{thread_id}")
async def websocket_endpoint(websocket: WebSocket, thread_id: str):
"""WebSocket endpoint for real-time thread updates."""
# Verify thread exists before accepting connection
if not await thread_manager.thread_exists(thread_id):
await websocket.close(code=4004, reason="Thread not found")
return
await ws_manager.connect(websocket, thread_id)
try:
while True:
await websocket.receive_json()
except WebSocketDisconnect:
ws_manager.disconnect(websocket, thread_id)
except Exception as e:
await websocket.send_json({"type": "error", "detail": str(e)})
ws_manager.disconnect(websocket, thread_id)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -0,0 +1 @@
# Empty file to mark as package

255
agentpress/api/api.py Normal file
View File

@ -0,0 +1,255 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel
from agentpress.thread_manager import ThreadManager
import asyncio
import uvicorn
import logging
from agentpress.api.ws import ws_manager
from agentpress.api.api_factory import (
app as thread_task_app,
register_thread_task_api,
discover_tasks,
thread_manager as task_thread_manager
)
# from agentpress.api_factory import app as api_app, discover_tasks
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global managers
thread_manager: Optional[ThreadManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for FastAPI application."""
# Startup
global thread_manager
thread_manager = ThreadManager()
# Share thread_manager with task API
global task_thread_manager
task_thread_manager = thread_manager
# Wait for DB initialization
db = thread_manager.db
if db._initialization_task:
await db._initialization_task
# Run task discovery during startup
discover_tasks()
yield
# Shutdown
# Add any cleanup code here if needed
# Create FastAPI app
app = FastAPI(title="AgentPress API", lifespan=lifespan)
# Enable CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# # Import and mount the API Factory app
# try:
# # Run task discovery
# # discover_tasks()
# logger.info("Task discovery completed")
# # Mount the API Factory app at /tasks instead of root
# app.mount("/tasks", api_app)
# logger.info("Mounted API Factory app at /tasks")
# except Exception as e:
# logger.error(f"Error setting up API Factory: {e}")
# raise
# Pydantic models for request/response validation
class MessageCreate(BaseModel):
"""Model for creating messages in a thread."""
message_data: Union[str, Dict[str, Any]]
images: Optional[List[Dict[str, Any]]] = None
include_in_llm_message_history: bool = True
message_type: Optional[str] = None
# REST API Endpoints
@app.post("/threads", response_model=dict, status_code=201)
async def create_thread():
"""Create a new thread."""
thread_id = await thread_manager.create_thread()
return {"thread_id": thread_id}
@app.post("/threads/{thread_id}/messages", response_model=dict, status_code=201)
async def create_message(thread_id: str, message: MessageCreate, background_tasks: BackgroundTasks):
"""Create a new message in a thread."""
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
try:
await thread_manager.add_message(
thread_id=thread_id,
message_data=message.message_data,
images=message.images,
include_in_llm_message_history=message.include_in_llm_message_history,
message_type=message.message_type
)
# Broadcast to WebSocket connections
background_tasks.add_task(
ws_manager.broadcast_to_thread,
thread_id,
{"type": "message_created", "message": message.dict()}
)
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
# TODO: BROKEN FOR SOME REASON RETURNS [] SHOULD RETURN, LLM MESSAGE STYLE
@app.get("/threads/{thread_id}/llm_history_messages")
async def get_thread_llm_messages(
thread_id: str,
hide_tool_msgs: bool = False,
only_latest_assistant: bool = False,
):
"""Get messages from a thread with filtering options."""
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
messages = await thread_manager.get_llm_history_messages(
thread_id=thread_id,
hide_tool_msgs=hide_tool_msgs,
only_latest_assistant=only_latest_assistant,
)
return {"messages": messages}
@app.get("/threads/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
message_types: Optional[List[str]] = None,
limit: Optional[int] = 50,
offset: Optional[int] = 0,
before_timestamp: Optional[str] = None,
after_timestamp: Optional[str] = None,
include_in_llm_message_history: Optional[bool] = None,
order: str = "asc"
):
"""
Get messages from a thread with comprehensive filtering options.
Args:
thread_id: Thread identifier
message_types: Optional list of message types to filter by
limit: Maximum number of messages to return (default: 50)
offset: Number of messages to skip for pagination
before_timestamp: Optional filter for messages before timestamp
after_timestamp: Optional filter for messages after timestamp
include_in_llm_message_history: Optional filter for LLM history inclusion
order: Sort order - "asc" or "desc"
"""
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
try:
messages = await thread_manager.get_messages(
thread_id=thread_id,
message_types=message_types,
limit=limit,
offset=offset,
before_timestamp=before_timestamp,
after_timestamp=after_timestamp,
include_in_llm_message_history=include_in_llm_message_history,
order=order
)
return {"messages": messages}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
# TODO ONLY SEND POLLING UPDATES (IN EVEN HIGHER FREQUENCY THEN 1per sec) - IF THEY ARE ANY ACTIVE TASKS FOR THAT THREAD. AS LONG AS THEY ARE ACTIVE TASKS START & STOP THE POLLING BASED ON WHETHER THERE IS AN ACTIVE TASK FOR THE THREAD. IMPLEMENT in API_FACTORY as well to broadcast this ofc & trigger/disable the polling.
# WebSocket Endpoint
@app.websocket("/threads/{thread_id}")
async def websocket_endpoint(
websocket: WebSocket,
thread_id: str,
message_types: Optional[List[str]] = None,
limit: Optional[int] = 50,
offset: Optional[int] = 0,
before_timestamp: Optional[str] = None,
after_timestamp: Optional[str] = None,
include_in_llm_message_history: Optional[bool] = None,
order: str = "desc"
):
"""
WebSocket endpoint for real-time thread updates with filtering and pagination.
Query Parameters:
message_types: Optional list of message types to filter by
limit: Maximum number of messages to return (default: 50)
offset: Number of messages to skip (for pagination)
before_timestamp: Optional timestamp to filter messages before
after_timestamp: Optional timestamp to filter messages after
include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion
order: Sort order - "asc" or "desc" (default: desc)
"""
try:
if not await thread_manager.thread_exists(thread_id):
await websocket.close(code=4004, reason="Thread not found")
return
await ws_manager.connect(websocket, thread_id)
while True:
try:
# Get messages with all filters
result = await thread_manager.get_messages(
thread_id=thread_id,
message_types=message_types,
limit=limit,
offset=offset,
before_timestamp=before_timestamp,
after_timestamp=after_timestamp,
include_in_llm_message_history=include_in_llm_message_history,
order=order
)
# Send messages and pagination info
await websocket.send_json({
"type": "messages",
"data": result
})
# Poll every second
await asyncio.sleep(1)
except WebSocketDisconnect:
ws_manager.disconnect(websocket, thread_id)
break
except Exception as e:
logging.error(f"WebSocket error: {e}")
await websocket.send_json({
"type": "error",
"data": str(e)
})
ws_manager.disconnect(websocket, thread_id)
break
except Exception as e:
logging.error(f"WebSocket connection error: {e}")
try:
await websocket.close(code=1011, reason=str(e))
except:
pass
# Update the mounting of thread_task_app
app.mount("/tasks", thread_task_app)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -0,0 +1,349 @@
"""
Thread Task API Factory for registering and managing thread-associated long-running tasks.
"""
import sys
import os
import inspect
import uuid
import asyncio
import logging
import importlib
from functools import wraps
from typing import Callable, Dict, Any, Optional, List, ForwardRef
from fastapi import FastAPI, HTTPException, Request
from pydantic import create_model
from agentpress.thread_manager import ThreadManager
from contextlib import asynccontextmanager
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize managers at module level
thread_manager: Optional[ThreadManager] = None
_decorated_functions: Dict[str, Callable] = {}
_running_tasks: Dict[str, Dict[str, Any]] = {}
def find_project_root():
"""Find the project root by looking for pyproject.toml"""
current = os.path.abspath(os.path.dirname(__file__))
while current != '/':
if os.path.exists(os.path.join(current, 'pyproject.toml')):
return current
current = os.path.dirname(current)
return None
def discover_tasks():
"""
Discover all decorated functions in the project.
Scans from the project root (where pyproject.toml is located).
"""
logger.info("Starting task discovery")
# Find project root
project_root = find_project_root()
logger.info(f"Project root found at: {project_root}")
# Add project root to Python path if not already there
if project_root not in sys.path:
sys.path.insert(0, project_root)
# Walk through all Python files in the project
for root, _, files in os.walk(project_root):
for file in files:
if file.endswith('.py'):
module_path = os.path.join(root, file)
module_name = os.path.relpath(module_path, project_root)
module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.')
try:
logger.info(f"Attempting to import module: {module_name}")
module = importlib.import_module(module_name)
# Inspect all module members
for name, obj in inspect.getmembers(module):
if inspect.isfunction(obj):
# Check if this function has been decorated with register_thread_task_api
if hasattr(obj, '__closure__') and obj.__closure__:
for cell in obj.__closure__:
if cell.cell_contents in _decorated_functions.values():
path = next(
p for p, f in _decorated_functions.items()
if f == cell.cell_contents
)
_decorated_functions[path] = obj
logger.info(f"Registered function: {obj.__name__} at path: {path}")
except Exception as e:
logger.error(f"Error importing {module_name}: {e}")
logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for FastAPI application."""
global thread_manager
# Initialize ThreadManager if not already initialized
if thread_manager is None:
thread_manager = ThreadManager()
# Wait for DB initialization
if thread_manager.db._initialization_task:
await thread_manager.db._initialization_task
# Run task discovery during startup
discover_tasks()
yield
# Cleanup if needed
# Create FastAPI app
app = FastAPI(
title="Thread Task API",
description="API for managing thread-associated long-running tasks",
openapi_tags=[{
"name": "Generated Thread Tasks",
"description": "Dynamically generated endpoints for thread-associated tasks"
}],
lifespan=lifespan
)
# Add middleware to ensure thread_manager is available
@app.middleware("http")
async def ensure_thread_manager(request: Request, call_next):
"""Ensure thread_manager is initialized before handling requests."""
global thread_manager
if thread_manager is None:
thread_manager = ThreadManager()
if thread_manager.db._initialization_task:
await thread_manager.db._initialization_task
return await call_next(request)
def register_thread_task_api(path: str):
"""
Decorator to register a function as a thread-associated task API endpoint.
The decorated function must have thread_id as its first parameter.
"""
def decorator(func: Callable):
logger.info(f"Registering thread task API endpoint: {path} for function {func.__name__}")
_decorated_functions[path] = func
# Validate that thread_id is the first parameter
params = inspect.signature(func).parameters
if 'thread_id' not in params:
raise ValueError(f"Function {func.__name__} must have thread_id as a parameter")
# Create Pydantic model for function parameters
model_fields = {}
for name, param in params.items():
if name == 'self': # Skip self parameter for methods
continue
annotation = param.annotation
if annotation == inspect.Parameter.empty:
annotation = Any
# Convert string annotations to ForwardRef
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
default = ... if param.default == inspect.Parameter.empty else param.default
model_fields[name] = (annotation, default)
RequestModel = create_model(f'{func.__name__}Request', **model_fields)
# Register the start endpoint
@app.post(
f"{path}/start",
response_model=dict,
summary=f"Start {func.__name__}",
description=f"Start a new {func.__name__} task associated with a thread",
tags=["Generated Thread Tasks"]
)
async def start_task(params: RequestModel):
logger.info(f"Starting task at {path}/start")
# Validate thread exists
if not await thread_manager.thread_exists(params.thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
task_id = str(uuid.uuid4())
kwargs = params.dict()
# Create the task
task = asyncio.create_task(func(**kwargs))
# Store task with thread association
_running_tasks[task_id] = {
"thread_id": params.thread_id,
"task": task,
"status": "running",
"path": path,
"started_at": asyncio.get_event_loop().time()
}
# Add task info to thread messages
await thread_manager.add_message(
thread_id=params.thread_id,
message_data={
"type": "task_started",
"task_id": task_id,
"path": path,
"status": "running"
},
message_type="task_status",
include_in_llm_message_history=False
)
return {"task_id": task_id}
# Register the stop endpoint
@app.post(
f"{path}/stop/{{task_id}}",
response_model=dict,
summary=f"Stop {func.__name__}",
description=f"Stop a running {func.__name__} task",
tags=["Generated Thread Tasks"]
)
async def stop_task(task_id: str):
if task_id not in _running_tasks:
raise HTTPException(status_code=404, detail="Task not found")
task_info = _running_tasks[task_id]
task_info["task"].cancel()
task_info["status"] = "cancelled"
# Update thread with task cancellation
await thread_manager.add_message(
thread_id=task_info["thread_id"],
message_data={
"type": "task_stopped",
"task_id": task_id,
"path": task_info["path"],
"status": "cancelled"
},
message_type="task_status",
include_in_llm_message_history=False
)
return {"status": "stopped"}
# Register the status endpoint
@app.get(
f"{path}/status/{{task_id}}",
response_model=dict,
summary=f"Get {func.__name__} status",
description=f"Get the status of a {func.__name__} task",
tags=["Generated Thread Tasks"]
)
async def get_status(task_id: str):
if task_id not in _running_tasks:
raise HTTPException(status_code=404, detail="Task not found")
task_info = _running_tasks[task_id]
task = task_info["task"]
if task.done():
try:
result = task.result()
status = "completed"
if hasattr(result, '__aiter__'):
status = "streaming"
# Update thread with task completion
await thread_manager.add_message(
thread_id=task_info["thread_id"],
message_data={
"type": "task_completed",
"task_id": task_id,
"path": task_info["path"],
"status": status,
"result": result if status == "completed" else None
},
message_type="task_status",
include_in_llm_message_history=False
)
return {
"status": status,
"result": result if status == "completed" else None
}
except asyncio.CancelledError:
return {"status": "cancelled"}
except Exception as e:
error_status = {
"status": "failed",
"error": str(e)
}
# Update thread with task failure
await thread_manager.add_message(
thread_id=task_info["thread_id"],
message_data={
"type": "task_failed",
"task_id": task_id,
"path": task_info["path"],
"status": "failed",
"error": str(e)
},
message_type="task_status",
include_in_llm_message_history=False
)
return error_status
return {"status": "running"}
@wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
return wrapper
return decorator
@app.get("/threads/{thread_id}/tasks")
async def get_thread_tasks(thread_id: str):
"""Get all tasks associated with a thread."""
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
thread_tasks = {
task_id: {
"path": info["path"],
"status": info["status"],
"started_at": info["started_at"]
}
for task_id, info in _running_tasks.items()
if info["thread_id"] == thread_id
}
return {"tasks": thread_tasks}
@app.delete("/threads/{thread_id}/tasks")
async def stop_thread_tasks(thread_id: str):
"""Stop all tasks associated with a thread."""
if not await thread_manager.thread_exists(thread_id):
raise HTTPException(status_code=404, detail="Thread not found")
stopped_tasks = []
for task_id, info in list(_running_tasks.items()):
if info["thread_id"] == thread_id:
info["task"].cancel()
info["status"] = "cancelled"
stopped_tasks.append(task_id)
# Update thread with task cancellations
if stopped_tasks:
await thread_manager.add_message(
thread_id=thread_id,
message_data={
"type": "tasks_stopped",
"task_ids": stopped_tasks,
"status": "cancelled"
},
message_type="task_status",
include_in_llm_message_history=False
)
return {"stopped_tasks": stopped_tasks}

45
agentpress/api/ws.py Normal file
View File

@ -0,0 +1,45 @@
"""WebSocket management system for real-time updates."""
import logging
from typing import Dict, List
from fastapi import WebSocket, WebSocketDisconnect
class WebSocketManager:
"""Manages WebSocket connections for real-time thread updates."""
def __init__(self):
self.active_connections: Dict[str, List[WebSocket]] = {}
async def connect(self, websocket: WebSocket, thread_id: str):
"""Connect a WebSocket to a thread."""
await websocket.accept()
if thread_id not in self.active_connections:
self.active_connections[thread_id] = []
self.active_connections[thread_id].append(websocket)
def disconnect(self, websocket: WebSocket, thread_id: str):
"""Disconnect a WebSocket from a thread."""
if thread_id in self.active_connections:
self.active_connections[thread_id].remove(websocket)
if not self.active_connections[thread_id]:
del self.active_connections[thread_id]
async def broadcast_to_thread(self, thread_id: str, message: dict):
"""Broadcast a message to all connections in a thread."""
if thread_id in self.active_connections:
disconnected = []
for connection in self.active_connections[thread_id]:
try:
await connection.send_json(message)
except WebSocketDisconnect:
disconnected.append(connection)
except Exception as e:
logging.warning(f"Failed to send message to websocket: {e}")
disconnected.append(connection)
# Clean up disconnected connections
for connection in disconnected:
self.disconnect(connection, thread_id)
# Global WebSocket manager instance
ws_manager = WebSocketManager()

View File

@ -1,170 +0,0 @@
"""
API Factory for registering and managing FastAPI endpoints.
"""
import sys
import inspect
import importlib
import pkgutil
import uuid
import asyncio
import logging
import os
from functools import wraps
from typing import Callable, Dict, Any, Optional, List
from fastapi import FastAPI, BackgroundTasks, HTTPException
from pydantic import create_model, BaseModel
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
_decorated_functions: Dict[str, Callable] = {}
_running_tasks: Dict[str, asyncio.Task] = {}
def register_api_endpoint(path: str):
"""Decorator to register a function as an API endpoint with task management."""
def decorator(func: Callable):
logger.info(f"Registering API endpoint: {path} for function {func.__name__}")
_decorated_functions[path] = func
# Create Pydantic model for function parameters
params = inspect.signature(func).parameters
model_fields = {
name: (param.annotation if param.annotation != inspect.Parameter.empty else Any, ... if param.default == inspect.Parameter.empty else param.default)
for name, param in params.items()
if name != 'self' # Skip self parameter for methods
}
RequestModel = create_model(f'{func.__name__}Request', **model_fields)
# Register the start endpoint
@app.post(f"{path}/start", response_model=dict)
async def start_task(params: Optional[RequestModel] = None):
logger.info(f"Starting task at {path}/start")
task_id = str(uuid.uuid4())
kwargs = params.dict() if params else {}
_running_tasks[task_id] = asyncio.create_task(func(**kwargs))
return {"task_id": task_id}
# Register the stop endpoint
@app.post(f"{path}/stop/{{task_id}}")
async def stop_task(task_id: str):
if task_id not in _running_tasks:
raise HTTPException(status_code=404, detail="Task not found")
_running_tasks[task_id].cancel()
return {"status": "stopped"}
# Register the status endpoint
@app.get(f"{path}/status/{{task_id}}")
async def get_status(task_id: str):
if task_id not in _running_tasks:
raise HTTPException(status_code=404, detail="Task not found")
task = _running_tasks[task_id]
if task.done():
try:
result = task.result()
# Check if this is a streaming response
if hasattr(result, '__aiter__') or (
isinstance(result, dict) and (
any(hasattr(v, '__aiter__') for v in result.values()) or
# Check for streaming responses in iterations
(
'iterations' in result and
result['iterations'] and
any(hasattr(r.get('response'), '__aiter__') for r in result['iterations'])
)
)
):
return {"status": "streaming"}
return {"status": "completed", "result": result}
except asyncio.CancelledError:
return {"status": "cancelled"}
except Exception as e:
return {"status": "failed", "error": str(e)}
return {"status": "running"}
# Also register a direct endpoint for simple calls
@app.post(path)
async def direct_call(background_tasks: BackgroundTasks, params: Optional[RequestModel] = None):
kwargs = params.dict() if params else {}
task_id = str(uuid.uuid4())
task = asyncio.create_task(func(**kwargs))
_running_tasks[task_id] = task
return {"task_id": task_id}
logger.info(f"Successfully registered all endpoints for {path}")
@wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
return wrapper
return decorator
def find_project_root() -> str:
"""Find the project root by looking for pyproject.toml."""
current = os.path.abspath(os.path.dirname(__file__))
while current != '/':
if os.path.exists(os.path.join(current, 'pyproject.toml')):
return current
current = os.path.dirname(current)
return os.path.dirname(__file__) # Fallback to current directory
def discover_tasks():
"""
Discover all decorated functions in the project.
Scans from the project root (where pyproject.toml is located).
"""
logger.info("Starting task discovery")
# Find project root
project_root = find_project_root()
logger.info(f"Project root found at: {project_root}")
# Add project root to Python path if not already there
if project_root not in sys.path:
sys.path.insert(0, project_root)
# Walk through all Python files in the project
for root, _, files in os.walk(project_root):
for file in files:
if file.endswith('.py'):
module_path = os.path.join(root, file)
module_name = os.path.relpath(module_path, project_root)
module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.')
try:
logger.info(f"Attempting to import module: {module_name}")
module = importlib.import_module(module_name)
# Inspect all module members
for name, obj in inspect.getmembers(module):
if inspect.isfunction(obj):
# Check if this function has been decorated with @task
if any(
path for path, func in _decorated_functions.items()
if func.__name__ == obj.__name__ and func.__module__ == obj.__module__
):
logger.info(f"Found already registered function: {obj.__name__}")
continue
# Check for our decorator in the function's closure
if hasattr(obj, '__closure__') and obj.__closure__:
for cell in obj.__closure__:
if cell.cell_contents in _decorated_functions.values():
# Found a decorated function that wasn't registered
path = next(
p for p, f in _decorated_functions.items()
if f == cell.cell_contents
)
_decorated_functions[path] = obj
logger.info(f"Registered function: {obj.__name__} at path: {path}")
except Exception as e:
logger.error(f"Error importing {module_name}: {e}")
logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}")
# Auto-discover tasks on import
discover_tasks()

View File

@ -2,99 +2,44 @@ import os
import shutil
import click
import questionary
from typing import List, Dict, Optional, Tuple
from typing import Dict
import time
import pkg_resources
import requests
from packaging import version
import re
MODULES = {
"llm": {
"required": True,
"files": ["llm.py"],
"description": "LLM Interface - Core module for interacting with large language models (OpenAI, Anthropic, 100+ LLMs using the OpenAI Input/Output Format powered by LiteLLM). Handles API calls, response streaming, and model-specific configurations."
},
"tool": {
"required": True,
"files": [
"tool.py",
"tool_registry.py"
],
"description": "Tool System Foundation - Defines the base architecture for creating and managing tools. Includes the tool registry for registering, organizing, and accessing tool functions."
},
"processors": {
"required": True,
"files": [
"processor/base_processors.py",
"processor/llm_response_processor.py",
"processor/standard/standard_tool_parser.py",
"processor/standard/standard_tool_executor.py",
"processor/standard/standard_results_adder.py",
"processor/xml/xml_tool_parser.py",
"processor/xml/xml_tool_executor.py",
"processor/xml/xml_results_adder.py"
],
"description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns."
},
"thread_management": {
"required": True,
"files": [
"thread_manager.py",
"thread_viewer_ui.py"
],
"description": "Conversation Management System - Handles message threading, conversation history, and provides a UI for viewing conversation threads. Manages the flow of messages between the user, LLM, and tools."
},
"state_management": {
"required": True,
"files": ["state_manager.py"],
"description": "State Persistence System - Provides thread-safe storage and retrieval of conversation state, tool data, and other persistent information. Enables maintaining context across sessions and managing shared state between components."
},
"db_connection": {
"required": True,
"files": ["db_connection.py"],
"description": "Database Connection - Provides a connection to a SQLite database for storing and retrieving conversation state, tool data, and other persistent information."
}
}
PACKAGE_NAME = "agentpress"
PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json"
STARTER_EXAMPLES = {
"simple_web_dev_example_agent": {
"description": "Interactive web development agent with file and terminal manipulation capabilities. Demonstrates both standard and XML-based tool calling patterns.",
"files": {
"agent.py": "examples/simple_web_dev/agent.py",
"tools/files_tool.py": "examples/simple_web_dev/tools/files_tool.py",
"tools/terminal_tool.py": "examples/simple_web_dev/tools/terminal_tool.py",
".env.example": "examples/.env.example"
"agent.py": "agents/simple_web_dev/agent.py",
"tools/files_tool.py": "agents/simple_web_dev/tools/files_tool.py",
"tools/terminal_tool.py": "agents/simple_web_dev/tools/terminal_tool.py",
".env.example": "agents/.env.example"
}
}
}
PACKAGE_NAME = "agentpress"
PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json"
def check_for_updates() -> Tuple[Optional[str], Optional[str], bool]:
"""
Check if there's a newer version available on PyPI
Returns: (current_version, latest_version, update_available)
"""
def check_for_updates():
"""Check if there's a newer version available on PyPI"""
try:
current_version = pkg_resources.get_distribution(PACKAGE_NAME).version
response = requests.get(PYPI_URL, timeout=2)
response.raise_for_status() # Raise exception for bad status codes
response.raise_for_status()
latest_version = response.json()["info"]["version"]
# Compare versions properly using packaging.version
current_ver = version.parse(current_version)
latest_ver = version.parse(latest_version)
return current_version, latest_version, latest_ver > current_ver
except requests.RequestException:
# Handle network-related errors silently
return None, None, False
except Exception as e:
# Log other unexpected errors but don't break the CLI
click.echo(f"Warning: Failed to check for updates: {str(e)}", err=True)
return None, None, False
@ -102,7 +47,6 @@ def show_welcome():
"""Display welcome message with ASCII art"""
click.clear()
# Check for updates
current_version, latest_version, update_available = check_for_updates()
click.echo("""
@ -122,16 +66,17 @@ def show_welcome():
time.sleep(1)
def copy_module_files(src_dir: str, dest_dir: str, files: List[str]):
"""Copy module files from package to destination"""
def copy_package_files(src_dir: str, dest_dir: str):
"""Copy all package files except agents folder to destination"""
os.makedirs(dest_dir, exist_ok=True)
with click.progressbar(files, label='Copying files') as file_list:
for file in file_list:
src = os.path.join(src_dir, file)
dst = os.path.join(dest_dir, file)
os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copy2(src, dst)
def ignore_patterns(path, names):
# Ignore the agents directory and any __pycache__ directories
return [n for n in names if n == 'agents' or n == '__pycache__']
with click.progressbar(length=1, label='Copying files') as bar:
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True, ignore=ignore_patterns)
bar.update(1)
def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]):
"""Copy example files from package to destination"""
@ -142,19 +87,6 @@ def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]):
shutil.copy2(src, dst)
click.echo(f" ✓ Created {dest_path}")
def update_file_paths(file_path: str, replacements: Dict[str, str]):
"""Update file paths in the given file"""
with open(file_path, 'r') as f:
content = f.read()
for old, new in replacements.items():
# Escape special characters in the old string
escaped_old = re.escape(old)
content = re.sub(escaped_old, new, content)
with open(file_path, 'w') as f:
f.write(content)
@click.group()
def cli():
"""AgentPress CLI - Initialize your AgentPress modules"""
@ -165,7 +97,6 @@ def init():
"""Initialize AgentPress modules in your project"""
show_welcome()
# Set components directory name to 'agentpress'
components_dir = "agentpress"
if os.path.exists(components_dir):
@ -195,41 +126,20 @@ def init():
# Get package directory
package_dir = os.path.dirname(os.path.abspath(__file__))
# Show all modules status
click.echo("\n🔧 AgentPress Modules Configuration\n")
# Show required modules including state_manager
click.echo("📦 Required Modules (pre-selected):")
required_modules = {name: module for name, module in MODULES.items()
if module["required"] or name == "state_management"}
for name, module in required_modules.items():
click.echo(f"{click.style(name, fg='green')} - {module['description']}")
# Create selections dict with required modules pre-selected
selections = {name: True for name in required_modules.keys()}
click.echo("\n🚀 Setting up your AgentPress...")
time.sleep(0.5)
try:
# Copy selected modules
selected_modules = [name for name, selected in selections.items() if selected]
all_files = []
for module in selected_modules:
all_files.extend(MODULES[module]["files"])
# Create components directory and copy module files
# Create components directory and copy all files except agents folder
components_dir_path = os.path.abspath(components_dir)
copy_module_files(package_dir, components_dir_path, all_files)
copy_package_files(package_dir, components_dir_path)
# Copy example only if a valid example (not None) was selected
# Copy example if selected
if selected_example and selected_example in STARTER_EXAMPLES:
click.echo(f"\n📝 Creating {selected_example}...")
copy_example_files(
package_dir,
os.getcwd(), # Use current working directory
os.getcwd(),
STARTER_EXAMPLES[selected_example]["files"]
)
@ -246,7 +156,6 @@ def init():
click.echo(f"\nRun the example agent:")
click.echo(" python agent.py")
except Exception as e:
click.echo(f"\n❌ Error during setup: {str(e)}", err=True)
return

View File

@ -7,105 +7,122 @@ import logging
from contextlib import asynccontextmanager
import os
import asyncio
import json
class DBConnection:
"""Singleton database connection manager."""
_instance = None
_initialized = False
_db_path = "ap.db"
_db_path = "/Users/markokraemer/Projects/softgen/softgen-core/main.db"
_init_lock = asyncio.Lock()
_initialization_task = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
# Start initialization when instance is first created
cls._initialization_task = asyncio.create_task(cls._instance._initialize())
return cls._instance
async def initialize(self):
"""Initialize the database connection and schema."""
if self._initialized:
def __init__(self):
"""No initialization needed in __init__ as it's handled in __new__"""
pass
@classmethod
async def _initialize(cls):
"""Internal initialization method."""
if cls._initialized:
return
try:
# Ensure the database directory exists
os.makedirs(os.path.dirname(os.path.abspath(self._db_path)), exist_ok=True)
# Initialize database and create schema
async with aiosqlite.connect(self._db_path) as db:
await db.execute("PRAGMA foreign_keys = ON")
async with cls._init_lock:
if cls._initialized: # Double-check after acquiring lock
return
# Create threads table
await db.execute("""
CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
try:
async with aiosqlite.connect(cls._db_path) as db:
# Threads table
await db.execute("""
CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Messages table
await db.execute("""
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
thread_id TEXT,
type TEXT,
content TEXT,
include_in_llm_message_history BOOLEAN DEFAULT TRUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (thread_id) REFERENCES threads (id)
)
""")
await db.commit()
cls._initialized = True
logging.info("Database schema initialized")
except Exception as e:
logging.error(f"Database initialization error: {e}")
raise
# Create events table
await db.execute("""
CREATE TABLE IF NOT EXISTS events (
id TEXT PRIMARY KEY,
thread_id TEXT,
type TEXT,
content TEXT,
include_in_llm_message_history INTEGER DEFAULT 0,
llm_message TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE
)
""")
@classmethod
def set_db_path(cls, db_path: str):
"""Set custom database path."""
if cls._initialized:
raise RuntimeError("Cannot change database path after initialization")
cls._db_path = db_path
logging.info(f"Updated database path to: {db_path}")
await db.commit()
logging.info("Database initialized successfully")
self._initialized = True
except Exception as e:
logging.error(f"Failed to initialize database: {e}")
raise
@asynccontextmanager
async def connection(self):
"""Get a database connection."""
# Wait for initialization to complete if it hasn't already
if self._initialization_task and not self._initialized:
await self._initialization_task
async with aiosqlite.connect(self._db_path) as conn:
try:
yield conn
except Exception as e:
logging.error(f"Database error: {e}")
raise
@asynccontextmanager
async def transaction(self):
"""Get a database connection with transaction support."""
if not self._initialized:
raise Exception("Database not initialized. Call initialize() first.")
async with aiosqlite.connect(self._db_path) as db:
await db.execute("PRAGMA foreign_keys = ON")
"""Execute operations in a transaction."""
async with self.connection() as db:
try:
yield db
await db.commit()
logging.debug("Transaction committed successfully")
except Exception as e:
await db.rollback()
logging.error(f"Transaction failed, rolling back: {e}")
logging.error(f"Transaction error: {e}")
raise
async def execute(self, query: str, params: tuple = ()):
"""Execute a query and return the cursor."""
async with aiosqlite.connect(self._db_path) as db:
await db.execute("PRAGMA foreign_keys = ON")
return await db.execute(query, params)
async def fetch_all(self, query: str, params: tuple = ()):
"""Execute a query and fetch all results."""
async with aiosqlite.connect(self._db_path) as db:
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(query, params)
return await cursor.fetchall()
"""Execute a single query."""
async with self.connection() as db:
try:
result = await db.execute(query, params)
await db.commit()
return result
except Exception as e:
logging.error(f"Query execution error: {e}")
raise
async def fetch_one(self, query: str, params: tuple = ()):
"""Execute a query and fetch one result."""
async with aiosqlite.connect(self._db_path) as db:
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(query, params)
return await cursor.fetchone()
"""Fetch a single row."""
async with self.connection() as db:
async with db.execute(query, params) as cursor:
return await cursor.fetchone()
def _serialize_json(self, data):
"""Serialize data to JSON string."""
return json.dumps(data) if data is not None else None
def _deserialize_json(self, data):
"""Deserialize JSON string to data."""
return json.loads(data) if data is not None else None
async def fetch_all(self, query: str, params: tuple = ()):
"""Fetch all rows."""
async with self.connection() as db:
async with db.execute(query, params) as cursor:
return await cursor.fetchall()

View File

@ -1,4 +1,4 @@
from typing import Union, Dict, Any
from typing import Union, Dict, Any, Optional, List
import litellm
import os
import json
@ -11,6 +11,14 @@ OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', None)
ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY', None)
GROQ_API_KEY = os.environ.get('GROQ_API_KEY', None)
AGENTOPS_API_KEY = os.environ.get('AGENTOPS_API_KEY', None)
FIREWORKS_API_KEY = os.environ.get('FIREWORKS_AI_API_KEY', None)
DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY', None)
OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', None)
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', None)
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', None)
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', None)
AWS_REGION_NAME = os.environ.get('AWS_REGION_NAME', None)
if OPENAI_API_KEY:
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
@ -18,6 +26,22 @@ if ANTHROPIC_API_KEY:
os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY
if GROQ_API_KEY:
os.environ['GROQ_API_KEY'] = GROQ_API_KEY
if FIREWORKS_API_KEY:
os.environ['FIREWORKS_AI_API_KEY'] = FIREWORKS_API_KEY
if DEEPSEEK_API_KEY:
os.environ['DEEPSEEK_API_KEY'] = DEEPSEEK_API_KEY
if OPENROUTER_API_KEY:
os.environ['OPENROUTER_API_KEY'] = OPENROUTER_API_KEY
if GEMINI_API_KEY:
os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY
# Add AWS environment variables if they exist
if AWS_ACCESS_KEY_ID:
os.environ['AWS_ACCESS_KEY_ID'] = AWS_ACCESS_KEY_ID
if AWS_SECRET_ACCESS_KEY:
os.environ['AWS_SECRET_ACCESS_KEY'] = AWS_SECRET_ACCESS_KEY
if AWS_REGION_NAME:
os.environ['AWS_REGION_NAME'] = AWS_REGION_NAME
async def make_llm_api_call(
messages: list,
@ -31,7 +55,8 @@ async def make_llm_api_call(
api_base: str = None,
agentops_session: Any = None,
stream: bool = False,
top_p: float = None
top_p: float = None,
stop: Optional[Union[str, List[str]]] = None # Add stop parameter
) -> Union[Dict[str, Any], Any]:
"""
Make an API call to a language model using litellm.
@ -52,6 +77,7 @@ async def make_llm_api_call(
agentops_session (Any, optional): Session for agentops integration
stream (bool, optional): Whether to stream the response. Defaults to False
top_p (float, optional): Top-p sampling parameter
stop (Union[str, List[str]], optional): Up to 4 sequences where the API will stop generating tokens
Returns:
Union[Dict[str, Any], Any]: API response, either complete or streaming
@ -59,7 +85,7 @@ async def make_llm_api_call(
Raises:
Exception: If API call fails after retries
"""
# litellm.set_verbose = True
litellm.set_verbose = False
async def attempt_api_call(api_call_func, max_attempts=3):
"""
@ -75,10 +101,17 @@ async def make_llm_api_call(
Raises:
Exception: If all retry attempts fail
"""
nonlocal model_name # Add this to access model_name
for attempt in range(max_attempts):
try:
return await api_call_func()
except litellm.exceptions.RateLimitError as e:
# Check if it's Bedrock Claude and switch to direct Anthropic
if "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0" in model_name:
logging.info("Rate limit hit with Bedrock Claude, falling back to direct Anthropic API...")
model_name = "anthropic/claude-3-5-sonnet-latest"
continue
logging.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...")
await asyncio.sleep(30)
except OpenAIError as e:
@ -105,6 +138,10 @@ async def make_llm_api_call(
"stream": stream,
}
# Add stop parameter if provided
if stop is not None:
api_call_params["stop"] = stop
# Add optional parameters if provided
if api_key:
api_call_params["api_key"] = api_key
@ -129,6 +166,32 @@ async def make_llm_api_call(
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
}
# Add OpenRouter specific parameters
if "openrouter" in model_name.lower():
if settings.or_site_url:
api_call_params["headers"] = {
"HTTP-Referer": settings.or_site_url
}
if settings.or_app_name:
api_call_params["headers"] = {
"X-Title": settings.or_app_name
}
# Add special handling for Deepseek
if "deepseek" in model_name.lower():
api_call_params["frequency_penalty"] = 0.5
api_call_params["temperature"] = 0.7
api_call_params["presence_penalty"] = 0.1
# Add Bedrock-specific parameters
if "bedrock" in model_name.lower():
if settings.aws_access_key_id:
api_call_params["aws_access_key_id"] = settings.aws_access_key_id
if settings.aws_secret_access_key:
api_call_params["aws_secret_access_key"] = settings.aws_secret_access_key
if settings.aws_region_name:
api_call_params["aws_region_name"] = settings.aws_region_name
# Log the API request
# logging.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")
@ -137,10 +200,36 @@ async def make_llm_api_call(
response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
else:
response = await litellm.acompletion(**api_call_params)
# Log the API response
# logging.info(f"Received API response: {response}")
# # For streaming responses, attach cost tracking
# if stream:
# # Create a wrapper object to track costs across chunks
# cost_tracker = {
# "prompt_tokens": 0,
# "completion_tokens": 0,
# "total_tokens": 0,
# "cost": 0.0
# }
# # Get the cost per token for the model
# model_cost = litellm.model_cost.get(model_name, {})
# input_cost = model_cost.get('input_cost_per_token', 0)
# output_cost = model_cost.get('output_cost_per_token', 0)
# # Attach the cost tracker to the response
# response.cost_tracker = cost_tracker
# response.model_info = {
# "input_cost_per_token": input_cost,
# "output_cost_per_token": output_cost
# }
# else:
# # For non-streaming, cost is already included in the response
# response._hidden_params = {
# "response_cost": litellm.completion_cost(completion_response=response)
# }
return response
return await attempt_api_call(api_call)
@ -188,4 +277,37 @@ if __name__ == "__main__":
print(response.choices[0].message.content)
print()
asyncio.run(test_llm_api_call())
# asyncio.run(test_llm_api_call())
async def test_bedrock():
"""
Test function for Bedrock API call.
"""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello from Bedrock!"}
]
model_name = "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0"
response = await make_llm_api_call(messages, model_name, stream=True)
print("\n🤖 Streaming response from Bedrock:\n")
buffer = ""
async for chunk in response:
if isinstance(chunk, dict) and 'choices' in chunk:
content = chunk['choices'][0]['delta'].get('content', '')
else:
content = chunk.choices[0].delta.content
if content:
buffer += content
if content[-1].isspace():
print(buffer, end='', flush=True)
buffer = ""
if buffer:
print(buffer, flush=True)
print("\n✨ Stream completed.\n")
# Add test_bedrock to the test runs
# asyncio.run(test_bedrock())

View File

@ -0,0 +1 @@
# Empty file to mark as package

View File

@ -172,7 +172,7 @@ class ResultsAdderBase(ABC):
Attributes:
add_message: Callback for adding new messages
update_message: Callback for updating existing messages
get_messages: Callback for retrieving thread messages
get_llm_history_messages: Callback for retrieving thread messages
message_added: Flag tracking if initial message has been added
"""
@ -184,7 +184,7 @@ class ResultsAdderBase(ABC):
"""
self.add_message = thread_manager.add_message
self.update_message = thread_manager._update_message
self.get_messages = thread_manager.get_messages
self.get_llm_history_messages = thread_manager.get_llm_history_messages
self.message_added = False
@abstractmethod

View File

@ -9,12 +9,9 @@ This module provides comprehensive processing of LLM responses, including:
"""
import asyncio
from typing import Callable, Dict, Any, AsyncGenerator, Optional
from typing import Dict, Any, AsyncGenerator
import logging
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
class LLMResponseProcessor:
"""Handles LLM response processing and tool execution management.
@ -37,51 +34,30 @@ class LLMResponseProcessor:
def __init__(
self,
thread_id: str,
available_functions: Dict = None,
add_message_callback: Callable = None,
update_message_callback: Callable = None,
get_messages_callback: Callable = None,
parallel_tool_execution: bool = True,
tool_parser: Optional[ToolParserBase] = None,
tool_executor: Optional[ToolExecutorBase] = None,
results_adder: Optional[ResultsAdderBase] = None,
thread_manager = None
tool_executor: ToolExecutorBase,
tool_parser: ToolParserBase,
available_functions: Dict,
results_adder: ResultsAdderBase
):
"""Initialize the response processor.
Args:
thread_id: ID of the conversation thread
available_functions: Dictionary of available tool functions
add_message_callback: Callback for adding messages
update_message_callback: Callback for updating messages
get_messages_callback: Callback for listing messages
parallel_tool_execution: Whether to execute tools in parallel
tool_parser: Custom tool parser implementation
tool_executor: Custom tool executor implementation
tool_parser: Custom tool parser implementation
available_functions: Dictionary of available tool functions
results_adder: Custom results adder implementation
thread_manager: Optional thread manager instance
"""
self.thread_id = thread_id
self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution)
self.tool_parser = tool_parser or StandardToolParser()
self.available_functions = available_functions or {}
# Create minimal thread manager if needed
if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback):
class MinimalThreadManager:
def __init__(self, add_msg, update_msg, list_msg):
self.add_message = add_msg
self._update_message = update_msg
self.get_messages = list_msg
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback)
self.results_adder = results_adder or StandardResultsAdder(thread_manager)
# State tracking for streaming
self.tool_calls_buffer = {}
self.processed_tool_calls = set()
self.tool_executor = tool_executor
self.tool_parser = tool_parser
self.available_functions = available_functions
self.results_adder = results_adder
self.content_buffer = ""
self.tool_calls_buffer = {}
self.tool_calls_accumulated = []
self.processed_tool_calls = set()
self._executing_tools = set() # Track currently executing tools
async def process_stream(
self,
@ -92,8 +68,9 @@ class LLMResponseProcessor:
"""Process streaming LLM response and handle tool execution."""
pending_tool_calls = []
background_tasks = set()
stream_completed = False # New flag to track stream completion
async def handle_message_management(chunk):
async def handle_message_management(chunk, is_final=False):
try:
# Accumulate content
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
@ -106,28 +83,37 @@ class LLMResponseProcessor:
self.tool_calls_buffer
)
if parsed_message and 'tool_calls' in parsed_message:
self.tool_calls_accumulated = parsed_message['tool_calls']
new_tool_calls = [
tool_call for tool_call in parsed_message['tool_calls']
if tool_call['id'] not in self.processed_tool_calls
]
if new_tool_calls:
self.tool_calls_accumulated.extend(new_tool_calls)
# Handle tool execution and results
if execute_tools and self.tool_calls_accumulated:
new_tool_calls = [
tool_call for tool_call in self.tool_calls_accumulated
if tool_call['id'] not in self.processed_tool_calls
if (tool_call['id'] not in self.processed_tool_calls and
tool_call['id'] not in self._executing_tools)
]
if new_tool_calls:
if execute_tools_on_stream:
for tool_call in new_tool_calls:
self._executing_tools.add(tool_call['id'])
results = await self.tool_executor.execute_tool_calls(
tool_calls=new_tool_calls,
available_functions=self.available_functions,
thread_id=self.thread_id,
executed_tool_calls=self.processed_tool_calls
)
for result in results:
await self.results_adder.add_tool_result(self.thread_id, result)
self.processed_tool_calls.add(result['tool_call_id'])
else:
pending_tool_calls.extend(new_tool_calls)
self._executing_tools.discard(result['tool_call_id'])
# Add/update assistant message
message = {
@ -152,7 +138,10 @@ class LLMResponseProcessor:
)
# Handle stream completion
if chunk.choices[0].finish_reason:
if chunk.choices[0].finish_reason or is_final:
nonlocal stream_completed
stream_completed = True
if not execute_tools_on_stream and pending_tool_calls:
results = await self.tool_executor.execute_tool_calls(
tool_calls=pending_tool_calls,
@ -165,8 +154,16 @@ class LLMResponseProcessor:
self.processed_tool_calls.add(result['tool_call_id'])
pending_tool_calls.clear()
# Set final state on the chunk instead of returning it
chunk._final_state = {
"content": self.content_buffer,
"tool_calls": self.tool_calls_accumulated,
"processed_tool_calls": list(self.processed_tool_calls)
}
except Exception as e:
logging.error(f"Error in background task: {e}")
raise
try:
async for chunk in response_stream:
@ -175,9 +172,22 @@ class LLMResponseProcessor:
task.add_done_callback(background_tasks.discard)
yield chunk
# Create a final dummy chunk to handle completion
final_chunk = type('DummyChunk', (), {
'choices': [type('DummyChoice', (), {
'delta': type('DummyDelta', (), {'content': None}),
'finish_reason': 'stop'
})]
})()
# Process final state
await handle_message_management(final_chunk, is_final=True)
yield final_chunk
# Wait for all background tasks to complete
if background_tasks:
await asyncio.gather(*background_tasks, return_exceptions=True)
except Exception as e:
logging.error(f"Error in stream processing: {e}")
for task in background_tasks:

View File

@ -0,0 +1 @@
# Empty file to mark as package

View File

@ -81,6 +81,6 @@ class StandardResultsAdder(ResultsAdderBase):
- Checks for duplicate tool results before adding
- Adds result only if tool_call_id is unique
"""
messages = await self.get_messages(thread_id)
messages = await self.get_llm_history_messages(thread_id)
if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages):
await self.add_message(thread_id, result)

View File

@ -0,0 +1 @@
# Empty file to mark as package

View File

@ -79,7 +79,7 @@ class XMLResultsAdder(ResultsAdderBase):
"""
try:
# Get the original tool call to find the root tag
messages = await self.get_messages(thread_id)
messages = await self.get_llm_history_messages(thread_id)
assistant_msg = next((msg for msg in reversed(messages)
if msg['role'] == 'assistant'), None)
@ -107,10 +107,9 @@ class XMLResultsAdder(ResultsAdderBase):
await self.add_message(thread_id, result_message)
except Exception as e:
logging.error(f"Error adding tool result: {e}")
# Ensure the result is still added even if there's an error
logging.error(f"Error adding tool result: {e}") # Ensure the result is still added even if there's an error
result_message = {
"role": "user",
"content": f"Result for {result['name']}:\n{result['content']}"
}
await self.add_message(thread_id, result_message)
await self.add_message(thread_id, result_message)

View File

@ -38,6 +38,8 @@ class XMLToolExecutor(ToolExecutorBase):
"""
self.parallel = parallel
self.tool_registry = tool_registry or ToolRegistry()
# Add internal tracking of executed tools
self._executed_tools = set()
async def execute_tool_calls(
self,
@ -65,20 +67,28 @@ class XMLToolExecutor(ToolExecutorBase):
if executed_tool_calls is None:
executed_tool_calls = set()
# Filter out already executed tool calls
new_tool_calls = [
tool_call for tool_call in tool_calls
if tool_call['id'] not in executed_tool_calls
and tool_call['id'] not in self._executed_tools
]
if not new_tool_calls:
return []
if self.parallel:
return await self._execute_parallel(
tool_calls,
available_functions,
thread_id,
executed_tool_calls
)
results = await self._execute_parallel(new_tool_calls, available_functions, thread_id, executed_tool_calls)
else:
return await self._execute_sequential(
tool_calls,
available_functions,
thread_id,
executed_tool_calls
)
results = await self._execute_sequential(new_tool_calls, available_functions, thread_id, executed_tool_calls)
# Track executed tools internally
for tool_call in new_tool_calls:
self._executed_tools.add(tool_call['id'])
if executed_tool_calls is not None:
executed_tool_calls.add(tool_call['id'])
return results
async def _execute_parallel(
self,
@ -87,9 +97,10 @@ class XMLToolExecutor(ToolExecutorBase):
thread_id: str,
executed_tool_calls: Set[str]
) -> List[Dict[str, Any]]:
async def execute_single_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
if tool_call['id'] in executed_tool_calls:
logging.info(f"Tool call {tool_call['id']} already executed")
async def execute_single_tool(tool_call: Dict[str, Any]) -> Optional[Dict[str, Any]]:
# Double-check the tool hasn't been executed
if (tool_call['id'] in executed_tool_calls or
tool_call['id'] in self._executed_tools):
return None
try:

View File

@ -1,207 +1,118 @@
"""
Manages persistent state storage for AgentPress components using thread-based events.
The StateManager provides thread-safe access to state data stored as events in threads,
allowing components to save and retrieve data across sessions. Each state update
creates a new event containing the complete state.
"""
import json
import logging
from typing import Any, Optional, List, Dict
from asyncio import Lock
import uuid
from agentpress.thread_manager import ThreadManager
class StateManager:
"""
Manages persistent state storage for AgentPress components using thread events.
The StateManager provides thread-safe access to state data stored as events,
maintaining the complete state in each event for better consistency and tracking.
Attributes:
lock (Lock): Asyncio lock for thread-safe state access
thread_id (str): Thread ID for state storage
thread_manager (ThreadManager): Thread manager instance for event handling
Manages state storage using thread messages.
Each state message contains a complete snapshot of the state at that point in time.
"""
def __init__(self, thread_id: str):
"""
Initialize StateManager with thread ID.
Args:
thread_id (str): Thread ID for state storage
"""
self.lock = Lock()
self.thread_id = thread_id
"""Initialize StateManager with a thread ID."""
self.thread_manager = ThreadManager()
logging.info(f"StateManager initialized with thread_id: {self.thread_id}")
self.thread_id = thread_id
self._state_cache = None
logging.info(f"StateManager initialized for thread: {thread_id}")
async def initialize(self):
"""Initialize the thread manager."""
await self.thread_manager.initialize()
async def _get_state(self) -> Dict[str, Any]:
"""Get the current complete state."""
if self._state_cache is not None:
return self._state_cache.copy() # Return copy to prevent cache mutation
async def _ensure_initialized(self):
"""Ensure thread manager is initialized."""
if not self.thread_manager.db._initialized:
await self.initialize()
async def _get_current_state(self) -> dict:
"""Get the current state from the latest state event."""
await self._ensure_initialized()
events = await self.thread_manager.get_thread_events(
thread_id=self.thread_id,
event_types=["state"],
order_by="created_at",
order="DESC"
# Get the latest state message
rows = await self.thread_manager.db.fetch_all(
"""
SELECT content
FROM messages
WHERE thread_id = ? AND type = 'state_message'
ORDER BY created_at DESC LIMIT 1
""",
(self.thread_id,)
)
if events:
return events[0]["content"].get("state", {})
if rows:
try:
self._state_cache = json.loads(rows[0][0])
return self._state_cache.copy()
except json.JSONDecodeError:
logging.error("Failed to parse state JSON")
return {}
async def _save_state(self, state: dict):
"""Save the complete state as a new event."""
await self._ensure_initialized()
await self.thread_manager.create_event(
async def _save_state(self, state: Dict[str, Any]):
"""Save a new complete state snapshot."""
# Format state as a string with proper indentation
formatted_state = json.dumps(state, indent=2)
# Save new state message with complete snapshot
await self.thread_manager.add_message(
thread_id=self.thread_id,
event_type="state",
content={"state": state}
message_data=formatted_state,
message_type='state_message',
include_in_llm_message_history=False
)
# Update cache with a copy
self._state_cache = state.copy()
async def set(self, key: str, data: Any) -> Any:
"""
Store data with a key in the state.
Args:
key (str): Simple string key like "config" or "settings"
data (Any): Any JSON-serializable data
Returns:
Any: The stored data
"""
async with self.lock:
try:
current_state = await self._get_current_state()
current_state[key] = data
await self._save_state(current_state)
logging.info(f'Updated state key: {key}')
return data
except Exception as e:
logging.error(f"Error setting state: {e}")
raise
"""Store any JSON-serializable data with a key."""
state = await self._get_state()
state[key] = data
await self._save_state(state)
logging.info(f'Updated state key: {key}')
return data
async def get(self, key: str) -> Optional[Any]:
"""
Get data for a key from the current state.
Args:
key (str): Simple string key like "config" or "settings"
Returns:
Any: The stored data for the key, or None if key not found
"""
async with self.lock:
try:
current_state = await self._get_current_state()
if key in current_state:
logging.info(f'Retrieved key: {key}')
return current_state[key]
logging.info(f'Key not found: {key}')
return None
except Exception as e:
logging.error(f"Error getting state: {e}")
raise
"""Get data for a key."""
state = await self._get_state()
if key in state:
data = state[key]
logging.info(f'Retrieved key: {key}')
return data
logging.info(f'Key not found: {key}')
return None
async def delete(self, key: str):
"""
Delete a key from the state.
Args:
key (str): Simple string key like "config" or "settings"
"""
async with self.lock:
try:
current_state = await self._get_current_state()
if key in current_state:
del current_state[key]
await self._save_state(current_state)
logging.info(f"Deleted key: {key}")
except Exception as e:
logging.error(f"Error deleting state: {e}")
raise
"""Delete data for a key."""
state = await self._get_state()
if key in state:
del state[key]
await self._save_state(state)
logging.info(f"Deleted key: {key}")
async def update(self, key: str, data: Dict[str, Any]) -> Optional[Any]:
"""
Update existing dictionary data for a key by merging.
Args:
key (str): Simple string key like "config" or "settings"
data (Dict[str, Any]): Dictionary of updates to merge
Returns:
Optional[Any]: Updated data if successful, None if key not found
"""
async with self.lock:
try:
current_state = await self._get_current_state()
if key in current_state and isinstance(current_state[key], dict):
current_state[key].update(data)
await self._save_state(current_state)
return current_state[key]
return None
except Exception as e:
logging.error(f"Error updating state: {e}")
raise
"""Update existing data for a key by merging dictionaries."""
state = await self._get_state()
if key in state and isinstance(state[key], dict):
state[key].update(data)
await self._save_state(state)
logging.info(f'Updated state key: {key}')
return state[key]
return None
async def append(self, key: str, item: Any) -> Optional[List[Any]]:
"""
Append an item to a list stored at key.
Args:
key (str): Simple string key like "config" or "settings"
item (Any): Item to append
Returns:
Optional[List[Any]]: Updated list if successful, None if key not found
"""
async with self.lock:
try:
current_state = await self._get_current_state()
if key not in current_state:
current_state[key] = []
if isinstance(current_state[key], list):
current_state[key].append(item)
await self._save_state(current_state)
return current_state[key]
return None
except Exception as e:
logging.error(f"Error appending to state: {e}")
raise
"""Append an item to a list stored at key."""
state = await self._get_state()
if key not in state:
state[key] = []
if isinstance(state[key], list):
state[key].append(item)
await self._save_state(state)
logging.info(f'Appended to key: {key}')
return state[key]
return None
async def get_latest_state(self) -> dict:
"""
Get the latest complete state.
Returns:
dict: Complete contents of the latest state
"""
async with self.lock:
try:
state = await self._get_current_state()
logging.info(f"Retrieved latest state with {len(state)} keys")
return state
except Exception as e:
logging.error(f"Error getting latest state: {e}")
raise
async def export_store(self) -> dict:
"""Export entire state."""
state = await self._get_state()
return state
async def clear_state(self):
"""
Clear the entire state.
"""
async with self.lock:
try:
await self._save_state({})
logging.info("Cleared state")
except Exception as e:
logging.error(f"Error clearing state: {e}")
raise
async def clear_store(self):
"""Clear entire state."""
await self._save_state({})
self._state_cache = {}
logging.info("Cleared state")

File diff suppressed because it is too large Load Diff

View File

@ -1,64 +1,111 @@
import streamlit as st
from datetime import datetime
from agentpress.thread_manager import ThreadManager
from agentpress.db_connection import DBConnection
import asyncio
import json
def format_message_content(content):
"""Format message content handling both string and list formats."""
if isinstance(content, str):
return content
elif isinstance(content, list):
formatted_content = []
for item in content:
if item.get('type') == 'text':
formatted_content.append(item['text'])
elif item.get('type') == 'image_url':
formatted_content.append("[Image]")
return "\n".join(formatted_content)
return str(content)
"""Format message content handling various formats."""
try:
if isinstance(content, str):
# Try to parse JSON strings
try:
parsed = json.loads(content)
if isinstance(parsed, (dict, list)):
return json.dumps(parsed, indent=2)
except json.JSONDecodeError:
return content
elif isinstance(content, list):
formatted_content = []
for item in content:
if item.get('type') == 'text':
formatted_content.append(item['text'])
elif item.get('type') == 'image_url':
formatted_content.append("[Image]")
return "\n".join(formatted_content)
return json.dumps(content, indent=2)
except:
return str(content)
async def load_threads():
"""Load all thread IDs from the database."""
db = DBConnection()
rows = await db.fetch_all("SELECT thread_id, created_at FROM threads ORDER BY created_at DESC")
rows = await db.fetch_all(
"""
SELECT id, created_at
FROM threads
ORDER BY created_at DESC
"""
)
return rows
async def load_thread_content(thread_id: str):
"""Load the content of a specific thread from the database."""
thread_manager = ThreadManager()
return await thread_manager.get_messages(thread_id)
async def load_thread_content(thread_id: str, filters: dict):
"""Load messages from a thread with filters."""
db = DBConnection()
query_parts = ["SELECT type, content, include_in_llm_message_history, created_at FROM messages WHERE thread_id = ?"]
params = [thread_id]
if filters.get('message_types'):
# Convert comma-separated string to list and clean up whitespace
types_list = [t.strip() for t in filters['message_types'].split(',') if t.strip()]
if types_list:
query_parts.append("AND type IN (" + ",".join(["?" for _ in types_list]) + ")")
params.extend(types_list)
if filters.get('exclude_message_types'):
# Convert comma-separated string to list and clean up whitespace
exclude_types_list = [t.strip() for t in filters['exclude_message_types'].split(',') if t.strip()]
if exclude_types_list:
query_parts.append("AND type NOT IN (" + ",".join(["?" for _ in exclude_types_list]) + ")")
params.extend(exclude_types_list)
if filters.get('before_timestamp'):
query_parts.append("AND created_at < ?")
params.append(filters['before_timestamp'])
if filters.get('after_timestamp'):
query_parts.append("AND created_at > ?")
params.append(filters['after_timestamp'])
if filters.get('include_in_llm_message_history') is not None:
query_parts.append("AND include_in_llm_message_history = ?")
params.append(filters['include_in_llm_message_history'])
# Add ordering
order_direction = "DESC" if filters.get('order', 'asc').lower() == 'desc' else "ASC"
query_parts.append(f"ORDER BY created_at {order_direction}")
# Add limit and offset
if filters.get('limit'):
query_parts.append("LIMIT ?")
params.append(filters['limit'])
if filters.get('offset'):
query_parts.append("OFFSET ?")
params.append(filters['offset'])
query = " ".join(query_parts)
rows = await db.fetch_all(query, tuple(params))
return rows
def render_message(role, content, avatar):
"""Render a message with a consistent chat-like style."""
# Create columns for avatar and message
col1, col2 = st.columns([1, 11])
# Style based on role
if role == "assistant":
bgcolor = "rgba(25, 25, 25, 0.05)"
elif role == "user":
bgcolor = "rgba(25, 120, 180, 0.05)"
elif role == "system":
bgcolor = "rgba(180, 25, 25, 0.05)"
else:
bgcolor = "rgba(100, 100, 100, 0.05)"
# Display avatar in first column
def render_message(msg_type: str, content: str, include_in_llm: bool, timestamp: str):
"""Render a message using Streamlit components."""
# Message type and metadata
col1, col2 = st.columns([3, 1])
with col1:
st.markdown(f"<div style='text-align: center; font-size: 24px;'>{avatar}</div>", unsafe_allow_html=True)
# Display message in second column
st.text(f"Type: {msg_type}")
with col2:
st.markdown(
f"""
<div style='background-color: {bgcolor}; padding: 10px; border-radius: 5px;'>
<strong>{role.upper()}</strong><br>
{content}
</div>
""",
unsafe_allow_html=True
)
st.text("🟢 LLM" if include_in_llm else "⚫ Non-LLM")
# Timestamp
st.text(f"Time: {datetime.fromisoformat(timestamp).strftime('%Y-%m-%d %H:%M:%S')}")
# Message content
st.code(content, language="json")
# Separator
st.divider()
def main():
st.title("Thread Viewer")
@ -86,7 +133,6 @@ def main():
)
if selected_thread_display:
# Get the actual thread ID from the display string
selected_thread_id = thread_options[selected_thread_display]
# Display thread ID in sidebar
@ -95,46 +141,77 @@ def main():
# Add refresh button
if st.sidebar.button("🔄 Refresh Thread"):
st.session_state.threads = asyncio.run(load_threads())
st.experimental_rerun()
st.rerun()
# Load and display messages
messages = asyncio.run(load_thread_content(selected_thread_id))
# Advanced filtering options in sidebar
st.sidebar.title("Filter Options")
# Display messages in chat-like interface
for message in messages:
role = message.get("role", "unknown")
content = message.get("content", "")
# Determine avatar based on role
if role == "assistant":
avatar = "🤖"
elif role == "user":
avatar = "👤"
elif role == "system":
avatar = "⚙️"
elif role == "tool":
avatar = "🔧"
else:
avatar = ""
# Format the content
# Message type filters
col1, col2 = st.sidebar.columns(2)
with col1:
message_types = st.text_input(
"Include Types",
help="Enter message types to include, separated by commas"
)
with col2:
exclude_message_types = st.text_input(
"Exclude Types",
help="Enter message types to exclude, separated by commas"
)
# Limit and offset
col1, col2 = st.sidebar.columns(2)
with col1:
limit = st.number_input("Limit", min_value=1, value=50)
with col2:
offset = st.number_input("Offset", min_value=0, value=0)
# Timestamp filters
st.sidebar.subheader("Time Range")
before_timestamp = st.sidebar.date_input("Before Date", value=None)
after_timestamp = st.sidebar.date_input("After Date", value=None)
# LLM history filter
include_in_llm = st.sidebar.radio(
"LLM History Filter",
options=["All Messages", "LLM Only", "Non-LLM Only"]
)
# Sort order
order = st.sidebar.radio("Sort Order", ["Ascending", "Descending"])
# Prepare filters
filters = {
'message_types': message_types if message_types else None,
'exclude_message_types': exclude_message_types if exclude_message_types else None,
'limit': limit,
'offset': offset,
'order': 'desc' if order == "Descending" else 'asc'
}
# Add timestamp filters if selected
if before_timestamp:
filters['before_timestamp'] = before_timestamp.isoformat()
if after_timestamp:
filters['after_timestamp'] = after_timestamp.isoformat()
# Add LLM history filter
if include_in_llm == "LLM Only":
filters['include_in_llm_message_history'] = True
elif include_in_llm == "Non-LLM Only":
filters['include_in_llm_message_history'] = False
# Load messages with filters
messages = asyncio.run(load_thread_content(selected_thread_id, filters))
if not messages:
st.info("No messages found with current filters")
return
# Display messages
for msg_type, content, include_in_llm, timestamp in messages:
formatted_content = format_message_content(content)
# Render the message
render_message(role, formatted_content, avatar)
# Display tool calls if present
if "tool_calls" in message:
with st.expander("🛠️ Tool Calls"):
for tool_call in message["tool_calls"]:
st.code(
f"Function: {tool_call['function']['name']}\n"
f"Arguments: {tool_call['function']['arguments']}",
language="json"
)
# Add some spacing between messages
st.markdown("<br>", unsafe_allow_html=True)
render_message(msg_type, formatted_content, include_in_llm, timestamp)
if __name__ == "__main__":
main()

1
example/__init__.py Normal file
View File

@ -0,0 +1 @@
# Empty file to mark as package

261
example/agent.py Normal file
View File

@ -0,0 +1,261 @@
"""
Interactive web development agent supporting both XML and Standard LLM tool calling.
This agent can:
- Create and modify web projects
- Execute terminal commands
- Handle file operations
- Use either XML or Standard tool calling patterns
"""
import asyncio
import json
from agentpress.thread_manager import ThreadManager
from example.tools.files_tool import FilesTool
from agentpress.state_manager import StateManager
from example.tools.terminal_tool import TerminalTool
import logging
from typing import AsyncGenerator, Optional, Dict, Any
import sys
from agentpress.api.api_factory import register_thread_task_api
BASE_SYSTEM_MESSAGE = """
You are a world-class web developer who can create, edit, and delete files, and execute terminal commands.
You write clean, well-structured code. Keep iterating on existing files, continue working on this existing
codebase - do not omit previous progress; instead, keep iterating.
Available tools:
- create_file: Create new files with specified content
- delete_file: Remove existing files
- str_replace: Make precise text replacements in files
- execute_command: Run terminal commands
RULES:
- All current file contents are available to you in the <current_workspace_state> section
- Each file in the workspace state includes its full content
- Use str_replace for precise replacements in files
- NEVER include comments in any code you write - the code should be self-documenting
- Always maintain the full context of files when making changes
- When creating new files, write clean code without any comments or documentation
<available_tools>
[create_file(file_path, file_contents)] - Create new files
[delete_file(file_path)] - Delete existing files
[str_replace(file_path, old_str, new_str)] - Replace specific text in files
[execute_command(command)] - Execute terminal commands
</available_tools>
ALWAYS RESPOND WITH MULTIPLE SIMULTANEOUS ACTIONS:
<thoughts>
[Provide a concise overview of your planned changes and implementations]
</thoughts>
<actions>
[Include multiple tool calls]
</actions>
EDITING GUIDELINES:
1. Review the current file contents in the workspace state
2. Make targeted changes with str_replace
3. Write clean, self-documenting code without comments
4. Use create_file for new files and str_replace for modifications
Example workspace state for a file:
{
"index.html": {
"content": "<!DOCTYPE html>\\n<html>\\n<head>..."
}
}
Think deeply and step by step.
"""
XML_FORMAT = """
RESPONSE FORMAT:
Use XML tags to specify file operations:
<create-file file_path="path/to/file">
file contents here
</create-file>
<str-replace file_path="path/to/file">
<old_str>text to replace</old_str>
<new_str>replacement text</new_str>
</str-replace>
<delete-file file_path="path/to/file">
</delete-file>
<stop_session></stop_session>
"""
@register_thread_task_api("/agent")
async def run_agent(
thread_id: str,
max_iterations: int = 5,
user_input: Optional[str] = None,
) -> Dict[str, Any]:
"""Run the development agent with specified configuration.
Args:
thread_id (str): The ID of the thread.
max_iterations (int, optional): The maximum number of iterations. Defaults to 5.
user_input (Optional[str], optional): The user input. Defaults to None.
"""
thread_manager = ThreadManager()
state_manager = StateManager(thread_id)
if user_input:
await thread_manager.add_message(
thread_id,
{
"role": "user",
"content": user_input
}
)
thread_manager.add_tool(FilesTool, thread_id=thread_id)
thread_manager.add_tool(TerminalTool, thread_id=thread_id)
system_message = {
"role": "system",
"content": BASE_SYSTEM_MESSAGE + XML_FORMAT
}
iteration = 0
while iteration < max_iterations:
iteration += 1
files_tool = FilesTool(thread_id=thread_id)
state = await state_manager.export_store()
temporary_message_content = f"""
You are tasked to complete the LATEST USER REQUEST!
<latest_user_request>
{user_input}
</latest_user_request>
Current development environment workspace state:
<current_workspace_state>
{json.dumps(state, indent=2) if state else "{}"}
</current_workspace_state>
CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE.
"""
await thread_manager.add_message(
thread_id=thread_id,
message_data=temporary_message_content,
message_type="temporary_message",
include_in_llm_message_history=False
)
temporary_message = {
"role": "user",
"content": temporary_message_content
}
model_name = "anthropic/claude-3-5-sonnet-latest"
response = await thread_manager.run_thread(
thread_id=thread_id,
system_message=system_message,
model_name=model_name,
temperature=0.1,
max_tokens=8096,
tool_choice="auto",
temporary_message=temporary_message,
native_tool_calling=False,
xml_tool_calling=True,
stream=True,
execute_tools_on_stream=True,
parallel_tool_execution=True,
)
if isinstance(response, AsyncGenerator):
print("\n🤖 Assistant is responding:")
try:
async for chunk in response:
if hasattr(chunk.choices[0], 'delta'):
delta = chunk.choices[0].delta
if hasattr(delta, 'content') and delta.content is not None:
content = delta.content
print(content, end='', flush=True)
# Check for open_files_in_editor tag and continue if found
if '</open_files_in_editor>' in content:
print("\n📂 Opening files in editor, continuing to next iteration...")
continue
if hasattr(delta, 'tool_calls') and delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.function:
if tool_call.function.name:
print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
if tool_call.function.arguments:
print(f" {tool_call.function.arguments}", end='', flush=True)
print("\n✨ Response completed\n")
except Exception as e:
print(f"\n❌ Error processing stream: {e}", file=sys.stderr)
logging.error(f"Error processing stream: {e}")
else:
print("\nNon-streaming response received:", response)
# # Get latest assistant message and check for stop_session
# latest_msg = await thread_manager.get_llm_history_messages(
# thread_id=thread_id,
# only_latest_assistant=True
# )
# if latest_msg and '</stop_session>' in latest_msg:
# break
if __name__ == "__main__":
print("\n🚀 Welcome to AgentPress!")
project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ")
if not project_description.strip():
project_description = "Create a modern, responsive landing page"
print("\nChoose your agent type:")
print("1. XML-based Tool Calling")
print(" - Structured XML format for tool execution")
print(" - Parses tool calls using XML outputs in the LLM response")
print("\n2. Standard Function Calling")
print(" - Native LLM function calling format")
print(" - JSON-based parameter passing")
use_xml = input("\nSelect tool calling format [1/2] (default: 1): ").strip() != "2"
print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}")
print("Use Ctrl+C to stop the agent at any time.")
async def test_agent():
thread_manager = ThreadManager()
thread_id = await thread_manager.create_thread()
logging.info(f"Created new thread: {thread_id}")
try:
result = await run_agent(
thread_id=thread_id,
max_iterations=5,
user_input=project_description,
)
print("\n✅ Test completed successfully!")
except Exception as e:
print(f"\n❌ Test failed: {str(e)}")
raise
try:
asyncio.run(test_agent())
except KeyboardInterrupt:
print("\n⚠️ Test interrupted by user")
except Exception as e:
print(f"\n❌ Test failed with error: {str(e)}")
raise

297
example/tools/files_tool.py Normal file
View File

@ -0,0 +1,297 @@
import os
import asyncio
from pathlib import Path
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.state_manager import StateManager
from typing import Optional
class FilesTool(Tool):
"""File management tool for creating, updating, and deleting files.
This tool provides file operations within a workspace directory, with built-in
file filtering and state tracking capabilities.
Attributes:
workspace (str): Path to the workspace directory
EXCLUDED_FILES (set): Files to exclude from operations
EXCLUDED_DIRS (set): Directories to exclude
EXCLUDED_EXT (set): File extensions to exclude
SNIPPET_LINES (int): Context lines for edit previews
"""
# Excluded files, directories, and extensions
EXCLUDED_FILES = {
".DS_Store",
".gitignore",
"package-lock.json",
"postcss.config.js",
"postcss.config.mjs",
"jsconfig.json",
"components.json",
"tsconfig.tsbuildinfo",
"tsconfig.json",
}
EXCLUDED_DIRS = {
"node_modules",
".next",
"dist",
"build",
".git"
}
EXCLUDED_EXT = {
".ico",
".svg",
".png",
".jpg",
".jpeg",
".gif",
".bmp",
".tiff",
".webp",
".db",
".sql"
}
def __init__(self, thread_id: Optional[str] = None):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
if thread_id:
self.state_manager = StateManager(thread_id)
asyncio.create_task(self._init_workspace_state())
self.SNIPPET_LINES = 4
def _should_exclude_file(self, rel_path: str) -> bool:
"""Check if a file should be excluded based on path, name, or extension"""
# Check filename
filename = os.path.basename(rel_path)
if filename in self.EXCLUDED_FILES:
return True
# Check directory
dir_path = os.path.dirname(rel_path)
if any(excluded in dir_path for excluded in self.EXCLUDED_DIRS):
return True
# Check extension
_, ext = os.path.splitext(filename)
if ext.lower() in self.EXCLUDED_EXT:
return True
return False
async def _init_workspace_state(self):
"""Initialize or update the workspace state in JSON"""
files_state = {}
# Walk through workspace and record all files
for root, _, files in os.walk(self.workspace):
for file in files:
full_path = os.path.join(root, file)
rel_path = os.path.relpath(full_path, self.workspace)
# Skip excluded files
if self._should_exclude_file(rel_path):
continue
try:
with open(full_path, 'r') as f:
content = f.read()
files_state[rel_path] = {
"content": content
}
except Exception as e:
print(f"Error reading file {rel_path}: {e}")
except UnicodeDecodeError:
print(f"Skipping binary file: {rel_path}")
if hasattr(self, 'state_manager'):
await self.state_manager.set("files", files_state)
async def _update_workspace_state(self):
"""Update the workspace state after any file operation"""
await self._init_workspace_state()
@openapi_schema({
"type": "function",
"function": {
"name": "create_file",
"description": "Create a new file with the provided contents at a given path in the workspace",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to the file to be created"
},
"file_contents": {
"type": "string",
"description": "The content to write to the file"
}
},
"required": ["file_path", "file_contents"]
}
}
})
@xml_schema(
tag_name="create-file",
mappings=[
{"param_name": "file_path", "node_type": "attribute", "path": "."},
{"param_name": "file_contents", "node_type": "content", "path": "."}
],
example='''
<create-file file_path="path/to/file">
File contents go here
</create-file>
'''
)
async def create_file(self, file_path: str, file_contents: str) -> ToolResult:
try:
full_path = os.path.join(self.workspace, file_path)
if os.path.exists(full_path):
return self.fail_response(f"File '{file_path}' already exists. Use update_file to modify existing files.")
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, 'w') as f:
f.write(file_contents)
await self._update_workspace_state()
return self.success_response(f"File '{file_path}' created successfully.")
except Exception as e:
return self.fail_response(f"Error creating file: {str(e)}")
@openapi_schema({
"type": "function",
"function": {
"name": "delete_file",
"description": "Delete a file at the given path",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to the file to be deleted"
}
},
"required": ["file_path"]
}
}
})
@xml_schema(
tag_name="delete-file",
mappings=[
{"param_name": "file_path", "node_type": "attribute", "path": "."}
],
example='''
<delete-file file_path="path/to/file">
</delete-file>
'''
)
async def delete_file(self, file_path: str) -> ToolResult:
try:
full_path = os.path.join(self.workspace, file_path)
os.remove(full_path)
await self._update_workspace_state()
return self.success_response(f"File '{file_path}' deleted successfully.")
except Exception as e:
return self.fail_response(f"Error deleting file: {str(e)}")
@openapi_schema({
"type": "function",
"function": {
"name": "str_replace",
"description": "Replace text in file",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to the target file"
},
"old_str": {
"type": "string",
"description": "Text to be replaced (must appear exactly once)"
},
"new_str": {
"type": "string",
"description": "Replacement text"
}
},
"required": ["file_path", "old_str", "new_str"]
}
}
})
@xml_schema(
tag_name="str-replace",
mappings=[
{"param_name": "file_path", "node_type": "attribute", "path": "file_path"},
{"param_name": "old_str", "node_type": "element", "path": "old_str"},
{"param_name": "new_str", "node_type": "element", "path": "new_str"}
],
example='''
<str-replace file_path="path/to/file">
<old_str>text to replace</old_str>
<new_str>replacement text</new_str>
</str-replace>
'''
)
async def str_replace(self, file_path: str, old_str: str, new_str: str) -> ToolResult:
try:
full_path = Path(os.path.join(self.workspace, file_path))
if not full_path.exists():
return self.fail_response(f"File '{file_path}' does not exist")
content = full_path.read_text().expandtabs()
old_str = old_str.expandtabs()
new_str = new_str.expandtabs()
occurrences = content.count(old_str)
if occurrences == 0:
return self.fail_response(f"String '{old_str}' not found in file")
if occurrences > 1:
lines = [i+1 for i, line in enumerate(content.split('\n')) if old_str in line]
return self.fail_response(f"Multiple occurrences found in lines {lines}. Please ensure string is unique")
# Perform replacement
new_content = content.replace(old_str, new_str)
full_path.write_text(new_content)
# Update state after file modification
await self._update_workspace_state()
# Show snippet around the edit
replacement_line = content.split(old_str)[0].count('\n')
start_line = max(0, replacement_line - self.SNIPPET_LINES)
end_line = replacement_line + self.SNIPPET_LINES + new_str.count('\n')
snippet = '\n'.join(new_content.split('\n')[start_line:end_line + 1])
return self.success_response(f"Replacement successful. Snippet of changes:\n{snippet}")
except Exception as e:
return self.fail_response(f"Error replacing string: {str(e)}")
if __name__ == "__main__":
async def test_files_tool():
files_tool = FilesTool()
test_file_path = "test_file.txt"
test_content = "This is a test file."
updated_content = "This is an updated test file."
print(f"Using workspace directory: {files_tool.workspace}")
# Test create_file
create_result = await files_tool.create_file(test_file_path, test_content)
print("Create file result:", create_result)
# Test delete_file
delete_result = await files_tool.delete_file(test_file_path)
print("Delete file result:", delete_result)
# Test read_file after delete (should fail)
read_deleted_result = await files_tool.read_file(test_file_path)
print("Read deleted file result:", read_deleted_result)
asyncio.run(test_files_tool())

View File

@ -0,0 +1,73 @@
import os
import asyncio
import subprocess
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from typing import Optional
class TerminalTool(Tool):
"""Terminal command execution tool for workspace operations."""
def __init__(self, thread_id: Optional[str] = None):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
@openapi_schema({
"type": "function",
"function": {
"name": "execute_command",
"description": "Execute a shell command in the workspace directory",
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
}
},
"required": ["command"]
}
}
})
@xml_schema(
tag_name="execute-command",
mappings=[
{"param_name": "command", "node_type": "content", "path": "."}
],
example='''
<execute-command>
npm install package-name
</execute-command>
'''
)
async def execute_command(self, command: str) -> ToolResult:
original_dir = os.getcwd()
try:
os.chdir(self.workspace)
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self.workspace
)
stdout, stderr = await process.communicate()
output = stdout.decode() if stdout else ""
error = stderr.decode() if stderr else ""
success = process.returncode == 0
if success:
return self.success_response({
"output": output,
"error": error,
"exit_code": process.returncode,
"cwd": self.workspace
})
else:
return self.fail_response(f"Command failed with exit code {process.returncode}: {error}")
except Exception as e:
return self.fail_response(f"Error executing command: {str(e)}")
finally:
os.chdir(original_dir)

View File

@ -0,0 +1,373 @@
:root {
--primary-color: #2563eb;
--secondary-color: #1e40af;
--text-color: #1f2937;
--light-text: #6b7280;
--background: #ffffff;
--section-bg: #f3f4f6;
}
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
html {
scroll-behavior: smooth;
}
.scroll-top {
position: fixed;
bottom: 2rem;
right: 2rem;
background: var(--primary-color);
color: white;
width: 45px;
height: 45px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
opacity: 0;
visibility: hidden;
transition: all 0.3s ease;
z-index: 1000;
}
.scroll-top.visible {
opacity: 1;
visibility: visible;
}
.scroll-top:hover {
transform: translateY(-3px);
background: var(--secondary-color);
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
line-height: 1.6;
color: var(--text-color);
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 0 2rem;
}
.header {
position: fixed;
width: 100%;
background: var(--background);
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
z-index: 1000;
transition: transform 0.3s ease;
}
.header.scrolled {
transform: translateY(-100%);
}
.header:hover {
transform: translateY(0);
}
.nav {
display: flex;
justify-content: space-between;
align-items: center;
padding: 1rem 2rem;
}
.logo {
font-size: 1.5rem;
font-weight: bold;
color: var(--primary-color);
}
.nav-links {
display: flex;
gap: 2rem;
list-style: none;
}
.nav-links a {
text-decoration: none;
color: var(--text-color);
font-weight: 500;
transition: color 0.3s ease;
}
.nav-links a:hover {
color: var(--primary-color);
}
.mobile-nav-toggle {
display: none;
}
.hero {
padding: 8rem 0 4rem;
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
background-size: 200% 200%;
color: white;
text-align: center;
animation: gradientMove 10s ease infinite;
}
@keyframes gradientMove {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.hero h1 {
font-size: 3rem;
margin-bottom: 1rem;
}
.hero p {
font-size: 1.25rem;
margin-bottom: 2rem;
}
.cta-button {
padding: 1rem 2rem;
font-size: 1.1rem;
background: white;
color: var(--primary-color);
border: none;
border-radius: 5px;
cursor: pointer;
transition: transform 0.3s ease;
}
.cta-button:hover {
transform: translateY(-2px);
}
.features {
padding: 4rem 0;
background: var(--section-bg);
}
.features h2 {
text-align: center;
margin-bottom: 3rem;
}
.feature-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 2rem;
}
.feature-card {
background: white;
padding: 2rem;
border-radius: 10px;
text-align: center;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
transition: all 0.3s ease;
opacity: 0;
transform: translateY(20px);
animation: fadeInUp 0.6s ease forwards;
}
@keyframes fadeInUp {
to {
opacity: 1;
transform: translateY(0);
}
}
.feature-card:hover {
transform: translateY(-5px);
}
.feature-icon {
font-size: 2.5rem;
margin-bottom: 1rem;
}
.contact {
padding: 4rem 0;
}
.contact h2 {
text-align: center;
margin-bottom: 2rem;
}
.contact-form {
max-width: 600px;
margin: 0 auto;
display: flex;
flex-direction: column;
gap: 1rem;
}
.contact-form input,
.contact-form textarea {
padding: 0.8rem;
border: 2px solid #ddd;
border-radius: 5px;
font-size: 1rem;
transition: all 0.3s ease;
outline: none;
}
.contact-form input:focus,
.contact-form textarea:focus {
border-color: var(--primary-color);
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
}
.contact-form input:hover,
.contact-form textarea:hover {
border-color: var(--primary-color);
}
.contact-form textarea {
height: 150px;
resize: vertical;
}
.submit-button {
padding: 1rem;
background: var(--primary-color);
color: white;
border: none;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.submit-button:hover {
background: var(--secondary-color);
}
.testimonials {
padding: 4rem 0;
background: var(--section-bg);
}
.testimonials h2 {
text-align: center;
margin-bottom: 3rem;
}
.testimonial-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 2rem;
max-width: 1200px;
margin: 0 auto;
padding: 0 2rem;
}
.testimonial-card {
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.testimonial-text {
font-style: italic;
margin-bottom: 1rem;
}
.testimonial-author {
font-weight: bold;
color: var(--primary-color);
}
.footer {
background: var(--text-color);
color: white;
padding: 2rem 0;
text-align: center;
}
.social-links {
display: flex;
justify-content: center;
gap: 1.5rem;
margin: 1rem 0;
}
.social-links a {
color: white;
text-decoration: none;
font-size: 1.5rem;
transition: color 0.3s ease;
}
.social-links a:hover {
color: var(--primary-color);
}
@media (max-width: 768px) {
.nav-links {
display: none;
position: fixed;
top: 70px;
left: 0;
right: 0;
background: var(--background);
flex-direction: column;
padding: 2rem;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
text-align: center;
transform: translateY(-100%);
opacity: 0;
transition: transform 0.3s ease, opacity 0.3s ease;
}
.nav-links.active {
display: flex;
transform: translateY(0);
opacity: 1;
}
.mobile-nav-toggle {
display: block;
background: none;
border: none;
cursor: pointer;
}
.mobile-nav-toggle span {
display: block;
width: 25px;
height: 3px;
background: var(--text-color);
margin: 5px 0;
transition: 0.3s;
position: relative;
}
.mobile-nav-toggle.active span:nth-child(1) {
transform: rotate(45deg) translate(5px, 5px);
}
.mobile-nav-toggle.active span:nth-child(2) {
opacity: 0;
}
.mobile-nav-toggle.active span:nth-child(3) {
transform: rotate(-45deg) translate(7px, -6px);
}
.hero h1 {
font-size: 2.5rem;
}
.feature-grid {
grid-template-columns: 1fr;
}
}

View File

@ -0,0 +1,113 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Modern Landing Page</title>
<link rel="stylesheet" href="css/styles.css">
<style>
.loading {
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: white;
display: flex;
justify-content: center;
align-items: center;
z-index: 9999;
}
.loading::after {
content: '';
width: 40px;
height: 40px;
border: 4px solid #ddd;
border-top-color: var(--primary-color);
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.loaded {
opacity: 0;
visibility: hidden;
transition: all 0.5s ease;
}
</style>
</head>
<body>
<div class="loading"></div>
<header class="header">
<nav class="nav">
<div class="logo">Brand</div>
<ul class="nav-links">
<li><a href="#home">Home</a></li>
<li><a href="#features">Features</a></li>
<li><a href="#about">About</a></li>
<li><a href="#contact">Contact</a></li>
</ul>
<button class="mobile-nav-toggle" aria-label="Toggle menu">
<span></span>
<span></span>
<span></span>
</button>
</nav>
</header>
<main>
<section id="hero" class="hero">
<div class="container">
<h1>Welcome to the Future</h1>
<p>Experience innovation like never before</p>
<button class="cta-button">Get Started</button>
</div>
</section>
<section id="features" class="features">
<div class="container">
<h2>Our Features</h2>
<div class="feature-grid">
<div class="feature-card">
<span class="feature-icon">🚀</span>
<h3>Fast Performance</h3>
<p>Lightning-quick loading times</p>
</div>
<div class="feature-card">
<span class="feature-icon">🎨</span>
<h3>Beautiful Design</h3>
<p>Stunning visual experience</p>
</div>
<div class="feature-card">
<span class="feature-icon">📱</span>
<h3>Responsive</h3>
<p>Works on all devices</p>
</div>
</div>
</div>
</section>
<section id="contact" class="contact">
<div class="container">
<h2>Get in Touch</h2>
<form class="contact-form">
<input type="text" name="name" placeholder="Name" required>
<input type="email" name="email" placeholder="Email" required>
<textarea name="message" placeholder="Message" required></textarea>
<button type="submit" class="submit-button">Send Message</button>
</form>
</div>
</section>
</main>
<footer class="footer">
<div class="container">
<p>&copy; 2024 Brand. All rights reserved.</p>
</div>
</footer>
<button class="scroll-top" aria-label="Scroll to top"></button>
<script src="js/main.js"></script>
</body>
</html>

View File

@ -0,0 +1,85 @@
const mobileNavToggle = document.querySelector('.mobile-nav-toggle');
const navLinks = document.querySelector('.nav-links');
mobileNavToggle.addEventListener('click', () => {
navLinks.classList.toggle('active');
mobileNavToggle.classList.toggle('active');
});
document.addEventListener('click', (e) => {
if (!e.target.closest('.nav') && navLinks.classList.contains('active')) {
navLinks.classList.remove('active');
mobileNavToggle.classList.remove('active');
}
});
document.querySelectorAll('a[href^="#"]').forEach(anchor => {
anchor.addEventListener('click', function (e) {
e.preventDefault();
document.querySelector(this.getAttribute('href')).scrollIntoView({
behavior: 'smooth'
});
});
});
const form = document.querySelector('.contact-form');
form.addEventListener('submit', (e) => {
e.preventDefault();
const formData = new FormData(form);
const data = Object.fromEntries(formData);
console.log('Form submitted:', data);
form.reset();
});
let lastScroll = 0;
window.addEventListener('load', () => {
setTimeout(() => {
document.querySelector('.loading').classList.add('loaded');
}, 500);
});
const scrollTopBtn = document.querySelector('.scroll-top');
scrollTopBtn.addEventListener('click', () => {
window.scrollTo({
top: 0,
behavior: 'smooth'
});
});
const observeElements = () => {
const observer = new IntersectionObserver((entries) => {
entries.forEach(entry => {
if (entry.isIntersecting) {
entry.target.style.opacity = '1';
entry.target.style.transform = 'translateY(0)';
}
});
}, { threshold: 0.1 });
document.querySelectorAll('.feature-card').forEach(card => {
observer.observe(card);
card.style.opacity = '0';
card.style.transform = 'translateY(20px)';
card.style.transition = 'all 0.6s ease';
});
};
document.addEventListener('DOMContentLoaded', observeElements);
window.addEventListener('scroll', () => {
scrollTopBtn.classList.toggle('visible', window.scrollY > 500);
const header = document.querySelector('.header');
const currentScroll = window.pageYOffset;
if (currentScroll <= 0) {
header.classList.remove('scrolled');
return;
}
if (currentScroll > lastScroll && !header.classList.contains('scrolled')) {
header.classList.add('scrolled');
} else if (currentScroll < lastScroll && header.classList.contains('scrolled')) {
header.classList.remove('scrolled');
}
lastScroll = currentScroll;
});