mirror of https://github.com/kortix-ai/suna.git
api server, thread ws, api factory
This commit is contained in:
parent
a455980807
commit
e72b15c728
|
@ -0,0 +1 @@
|
|||
# Empty file to mark as package
|
|
@ -8,23 +8,22 @@ This agent can:
|
|||
- Use either XML or Standard tool calling patterns
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from tools.files_tool import FilesTool
|
||||
from example.tools.files_tool import FilesTool
|
||||
from agentpress.state_manager import StateManager
|
||||
from tools.terminal_tool import TerminalTool
|
||||
from agentpress.api_factory import register_api_endpoint
|
||||
from example.tools.terminal_tool import TerminalTool
|
||||
import logging
|
||||
from typing import AsyncGenerator, Optional, Dict, Any
|
||||
import sys
|
||||
|
||||
from agentpress.api.api_factory import register_thread_task_api
|
||||
|
||||
BASE_SYSTEM_MESSAGE = """
|
||||
You are a world-class web developer who can create, edit, and delete files, and execute terminal commands.
|
||||
You write clean, well-structured code. Keep iterating on existing files, continue working on this existing
|
||||
codebase - do not omit previous progress; instead, keep iterating.
|
||||
|
||||
Available tools:
|
||||
- create_file: Create new files with specified content
|
||||
- delete_file: Remove existing files
|
||||
|
@ -69,7 +68,6 @@ Example workspace state for a file:
|
|||
}
|
||||
}
|
||||
Think deeply and step by step.
|
||||
|
||||
"""
|
||||
|
||||
XML_FORMAT = """
|
||||
|
@ -88,87 +86,77 @@ file contents here
|
|||
<delete-file file_path="path/to/file">
|
||||
</delete-file>
|
||||
|
||||
<stop_session></stop_session>
|
||||
"""
|
||||
|
||||
def get_anthropic_api_key():
|
||||
"""Get Anthropic API key from environment or prompt user."""
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
api_key = input("\n🔑 Please enter your Anthropic API key: ").strip()
|
||||
if not api_key:
|
||||
print("❌ No API key provided. Please set ANTHROPIC_API_KEY environment variable or enter a key.")
|
||||
sys.exit(1)
|
||||
os.environ["ANTHROPIC_API_KEY"] = api_key
|
||||
return api_key
|
||||
|
||||
@register_api_endpoint("/main_agent")
|
||||
@register_thread_task_api("/agent")
|
||||
async def run_agent(
|
||||
thread_id: str,
|
||||
use_xml: bool = True,
|
||||
max_iterations: int = 5,
|
||||
project_description: Optional[str] = None
|
||||
user_input: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the development agent with specified configuration."""
|
||||
# Initialize managers
|
||||
thread_manager = ThreadManager()
|
||||
await thread_manager.initialize()
|
||||
"""Run the development agent with specified configuration.
|
||||
|
||||
state_manager = StateManager(thread_id)
|
||||
await state_manager.initialize()
|
||||
Args:
|
||||
thread_id (str): The ID of the thread.
|
||||
max_iterations (int, optional): The maximum number of iterations. Defaults to 5.
|
||||
user_input (Optional[str], optional): The user input. Defaults to None.
|
||||
"""
|
||||
thread_manager = ThreadManager()
|
||||
state_manager = StateManager(thread_id)
|
||||
|
||||
# Register tools
|
||||
thread_manager.add_tool(FilesTool, thread_id=thread_id)
|
||||
thread_manager.add_tool(TerminalTool, thread_id=thread_id)
|
||||
|
||||
# Add initial project description if provided
|
||||
if project_description:
|
||||
if user_input:
|
||||
await thread_manager.add_message(
|
||||
thread_id,
|
||||
{
|
||||
"role": "user",
|
||||
"content": project_description
|
||||
"content": user_input
|
||||
}
|
||||
)
|
||||
|
||||
# Set up system message with appropriate format
|
||||
thread_manager.add_tool(FilesTool, thread_id=thread_id)
|
||||
thread_manager.add_tool(TerminalTool, thread_id=thread_id)
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "")
|
||||
"content": BASE_SYSTEM_MESSAGE + XML_FORMAT
|
||||
}
|
||||
|
||||
# Create initial event to track agent loop
|
||||
await thread_manager.create_event(
|
||||
thread_id=thread_id,
|
||||
event_type="agent_loop_started",
|
||||
content={
|
||||
"max_iterations": max_iterations,
|
||||
"use_xml": use_xml,
|
||||
"project_description": project_description
|
||||
},
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
results = []
|
||||
iteration = 0
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
|
||||
files_tool = FilesTool(thread_id)
|
||||
await files_tool._init_workspace_state()
|
||||
files_tool = FilesTool(thread_id=thread_id)
|
||||
|
||||
state = await state_manager.get_latest_state()
|
||||
|
||||
state_message = {
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Current development environment workspace state:
|
||||
<current_workspace_state>
|
||||
{json.dumps(state, indent=2)}
|
||||
</current_workspace_state>
|
||||
"""
|
||||
state = await state_manager.export_store()
|
||||
|
||||
temporary_message_content = f"""
|
||||
You are tasked to complete the LATEST USER REQUEST!
|
||||
<latest_user_request>
|
||||
{user_input}
|
||||
</latest_user_request>
|
||||
|
||||
Current development environment workspace state:
|
||||
<current_workspace_state>
|
||||
{json.dumps(state, indent=2) if state else "{}"}
|
||||
</current_workspace_state>
|
||||
|
||||
CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE.
|
||||
"""
|
||||
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
message_data=temporary_message_content,
|
||||
message_type="temporary_message",
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
temporary_message = {
|
||||
"role": "user",
|
||||
"content": temporary_message_content
|
||||
}
|
||||
|
||||
model_name = "anthropic/claude-3-5-sonnet-latest"
|
||||
model_name = "anthropic/claude-3-5-sonnet-latest"
|
||||
|
||||
response = await thread_manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
|
@ -177,51 +165,57 @@ Current development environment workspace state:
|
|||
temperature=0.1,
|
||||
max_tokens=8096,
|
||||
tool_choice="auto",
|
||||
temporary_message=state_message,
|
||||
native_tool_calling=not use_xml,
|
||||
xml_tool_calling=use_xml,
|
||||
temporary_message=temporary_message,
|
||||
native_tool_calling=False,
|
||||
xml_tool_calling=True,
|
||||
stream=True,
|
||||
execute_tools_on_stream=False,
|
||||
parallel_tool_execution=True
|
||||
execute_tools_on_stream=True,
|
||||
parallel_tool_execution=True,
|
||||
)
|
||||
|
||||
# Handle both streaming and regular responses
|
||||
if hasattr(response, '__aiter__'):
|
||||
chunks = []
|
||||
if isinstance(response, AsyncGenerator):
|
||||
print("\n🤖 Assistant is responding:")
|
||||
try:
|
||||
async for chunk in response:
|
||||
chunks.append(chunk)
|
||||
if hasattr(chunk.choices[0], 'delta'):
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
if hasattr(delta, 'content') and delta.content is not None:
|
||||
content = delta.content
|
||||
print(content, end='', flush=True)
|
||||
|
||||
# Check for open_files_in_editor tag and continue if found
|
||||
if '</open_files_in_editor>' in content:
|
||||
print("\n📂 Opening files in editor, continuing to next iteration...")
|
||||
continue
|
||||
|
||||
if hasattr(delta, 'tool_calls') and delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
|
||||
if tool_call.function.arguments:
|
||||
print(f" {tool_call.function.arguments}", end='', flush=True)
|
||||
|
||||
print("\n✨ Response completed\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error processing stream: {e}", file=sys.stderr)
|
||||
logging.error(f"Error processing stream: {e}")
|
||||
raise
|
||||
response = chunks
|
||||
else:
|
||||
print("\nNon-streaming response received:", response)
|
||||
|
||||
results.append({
|
||||
"iteration": iteration,
|
||||
"response": response
|
||||
})
|
||||
# # Get latest assistant message and check for stop_session
|
||||
# latest_msg = await thread_manager.get_llm_history_messages(
|
||||
# thread_id=thread_id,
|
||||
# only_latest_assistant=True
|
||||
# )
|
||||
# if latest_msg and '</stop_session>' in latest_msg:
|
||||
# break
|
||||
|
||||
# Create iteration completion event
|
||||
await thread_manager.create_event(
|
||||
thread_id=thread_id,
|
||||
event_type="iteration_complete",
|
||||
content={
|
||||
"iteration_number": iteration,
|
||||
"max_iterations": max_iterations,
|
||||
# "state": state
|
||||
},
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
return {
|
||||
"thread_id": thread_id,
|
||||
"iterations": results,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n🚀 Welcome to AgentPress Web Developer Example!")
|
||||
|
||||
get_anthropic_api_key()
|
||||
print("\n🚀 Welcome to AgentPress!")
|
||||
|
||||
project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ")
|
||||
if not project_description.strip():
|
||||
|
@ -241,10 +235,27 @@ if __name__ == "__main__":
|
|||
print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}")
|
||||
print("Use Ctrl+C to stop the agent at any time.")
|
||||
|
||||
async def async_main():
|
||||
async def test_agent():
|
||||
thread_manager = ThreadManager()
|
||||
thread_id = await thread_manager.create_thread()
|
||||
logging.info(f"Created new thread: {thread_id}")
|
||||
await run_agent(thread_id, use_xml, project_description=project_description)
|
||||
|
||||
try:
|
||||
result = await run_agent(
|
||||
thread_id=thread_id,
|
||||
max_iterations=5,
|
||||
user_input=project_description,
|
||||
)
|
||||
print("\n✅ Test completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed: {str(e)}")
|
||||
raise
|
||||
|
||||
asyncio.run(async_main())
|
||||
try:
|
||||
asyncio.run(test_agent())
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Test interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed with error: {str(e)}")
|
||||
raise
|
|
@ -60,13 +60,8 @@ class FilesTool(Tool):
|
|||
os.makedirs(self.workspace, exist_ok=True)
|
||||
if thread_id:
|
||||
self.state_manager = StateManager(thread_id)
|
||||
asyncio.create_task(self._init_state())
|
||||
self.SNIPPET_LINES = 4 # Number of context lines to show around edits
|
||||
|
||||
async def _init_state(self):
|
||||
"""Initialize state manager and workspace state."""
|
||||
await self.state_manager.initialize()
|
||||
await self._init_workspace_state()
|
||||
asyncio.create_task(self._init_workspace_state())
|
||||
self.SNIPPET_LINES = 4
|
||||
|
||||
def _should_exclude_file(self, rel_path: str) -> bool:
|
||||
"""Check if a file should be excluded based on path, name, or extension"""
|
||||
|
@ -264,6 +259,9 @@ class FilesTool(Tool):
|
|||
new_content = content.replace(old_str, new_str)
|
||||
full_path.write_text(new_content)
|
||||
|
||||
# Update state after file modification
|
||||
await self._update_workspace_state()
|
||||
|
||||
# Show snippet around the edit
|
||||
replacement_line = content.split(old_str)[0].count('\n')
|
||||
start_line = max(0, replacement_line - self.SNIPPET_LINES)
|
|
@ -2,7 +2,6 @@ import os
|
|||
import asyncio
|
||||
import subprocess
|
||||
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
||||
from agentpress.state_manager import StateManager
|
||||
from typing import Optional
|
||||
|
||||
class TerminalTool(Tool):
|
||||
|
@ -12,24 +11,6 @@ class TerminalTool(Tool):
|
|||
super().__init__()
|
||||
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
if thread_id:
|
||||
self.state_manager = StateManager(thread_id)
|
||||
asyncio.create_task(self._init_state())
|
||||
|
||||
async def _init_state(self):
|
||||
"""Initialize state manager."""
|
||||
await self.state_manager.initialize()
|
||||
|
||||
async def _update_command_history(self, command: str, output: str, success: bool):
|
||||
"""Update command history in state"""
|
||||
history = await self.state_manager.get("terminal_history") or []
|
||||
history.append({
|
||||
"command": command,
|
||||
"output": output,
|
||||
"success": success,
|
||||
"cwd": os.path.relpath(os.getcwd(), self.workspace)
|
||||
})
|
||||
await self.state_manager.set("terminal_history", history)
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
|
@ -76,12 +57,6 @@ class TerminalTool(Tool):
|
|||
error = stderr.decode() if stderr else ""
|
||||
success = process.returncode == 0
|
||||
|
||||
await self._update_command_history(
|
||||
command=command,
|
||||
output=output + error,
|
||||
success=success
|
||||
)
|
||||
|
||||
if success:
|
||||
return self.success_response({
|
||||
"output": output,
|
||||
|
@ -93,11 +68,6 @@ class TerminalTool(Tool):
|
|||
return self.fail_response(f"Command failed with exit code {process.returncode}: {error}")
|
||||
|
||||
except Exception as e:
|
||||
await self._update_command_history(
|
||||
command=command,
|
||||
output=str(e),
|
||||
success=False
|
||||
)
|
||||
return self.fail_response(f"Error executing command: {str(e)}")
|
||||
finally:
|
||||
os.chdir(original_dir)
|
|
@ -1,253 +0,0 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
import asyncio
|
||||
import json
|
||||
import uvicorn
|
||||
import logging
|
||||
import importlib
|
||||
from agentpress.api_factory import app as api_app, discover_tasks
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global managers
|
||||
thread_manager: Optional[ThreadManager] = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for FastAPI application."""
|
||||
# Startup
|
||||
global thread_manager
|
||||
thread_manager = ThreadManager()
|
||||
await thread_manager.initialize()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
# Add any cleanup code here if needed
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(title="AgentPress API", lifespan=lifespan)
|
||||
|
||||
# Enable CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Import and mount the API Factory app
|
||||
try:
|
||||
# Run task discovery
|
||||
discover_tasks()
|
||||
logger.info("Task discovery completed")
|
||||
|
||||
# Mount the API Factory app at /tasks instead of root
|
||||
app.mount("/tasks", api_app)
|
||||
logger.info("Mounted API Factory app at /tasks")
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting up API Factory: {e}")
|
||||
raise
|
||||
|
||||
# WebSocket connection manager
|
||||
class WebSocketManager:
|
||||
"""Manages WebSocket connections for real-time thread updates."""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, List[WebSocket]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, thread_id: str):
|
||||
"""Connect a WebSocket to a thread."""
|
||||
await websocket.accept()
|
||||
if thread_id not in self.active_connections:
|
||||
self.active_connections[thread_id] = []
|
||||
self.active_connections[thread_id].append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket, thread_id: str):
|
||||
"""Disconnect a WebSocket from a thread."""
|
||||
if thread_id in self.active_connections:
|
||||
self.active_connections[thread_id].remove(websocket)
|
||||
if not self.active_connections[thread_id]:
|
||||
del self.active_connections[thread_id]
|
||||
|
||||
async def broadcast_to_thread(self, thread_id: str, message: dict):
|
||||
"""Broadcast a message to all connections in a thread."""
|
||||
if thread_id in self.active_connections:
|
||||
for connection in self.active_connections[thread_id]:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except WebSocketDisconnect:
|
||||
self.disconnect(connection, thread_id)
|
||||
|
||||
# Initialize WebSocket manager
|
||||
ws_manager = WebSocketManager()
|
||||
|
||||
# Pydantic models for request/response validation
|
||||
class EventCreate(BaseModel):
|
||||
event_type: str
|
||||
content: Dict[str, Any]
|
||||
include_in_llm_message_history: bool = False
|
||||
llm_message: Optional[Dict[str, Any]] = None
|
||||
|
||||
class EventUpdate(BaseModel):
|
||||
content: Optional[Dict[str, Any]] = None
|
||||
include_in_llm_message_history: Optional[bool] = None
|
||||
llm_message: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ThreadEvents(BaseModel):
|
||||
only_llm_messages: bool = False
|
||||
event_types: Optional[List[str]] = None
|
||||
order_by: str = "created_at"
|
||||
order: str = "ASC"
|
||||
|
||||
# REST API Endpoints
|
||||
@app.post("/threads", response_model=dict, status_code=201)
|
||||
async def create_thread():
|
||||
"""Create a new thread."""
|
||||
thread_id = await thread_manager.create_thread()
|
||||
return {"thread_id": thread_id}
|
||||
|
||||
@app.delete("/threads/{thread_id}", status_code=204)
|
||||
async def delete_thread(thread_id: str):
|
||||
"""Delete a thread and all its events."""
|
||||
success = await thread_manager.delete_thread(thread_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
return None
|
||||
|
||||
@app.post("/threads/{thread_id}/events", response_model=dict, status_code=201)
|
||||
async def create_event(thread_id: str, event: EventCreate, background_tasks: BackgroundTasks):
|
||||
"""Create a new event in a thread."""
|
||||
# First verify thread exists
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
try:
|
||||
event_id = await thread_manager.create_event(
|
||||
thread_id=thread_id,
|
||||
event_type=event.event_type,
|
||||
content=event.content,
|
||||
include_in_llm_message_history=event.include_in_llm_message_history,
|
||||
llm_message=event.llm_message
|
||||
)
|
||||
# Broadcast to WebSocket connections
|
||||
background_tasks.add_task(
|
||||
ws_manager.broadcast_to_thread,
|
||||
thread_id,
|
||||
{"type": "event_created", "event_id": event_id, "event": event.dict()}
|
||||
)
|
||||
return {"event_id": event_id}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@app.delete("/threads/{thread_id}/events/{event_id}", status_code=204)
|
||||
async def delete_event(thread_id: str, event_id: str, background_tasks: BackgroundTasks):
|
||||
"""Delete a specific event."""
|
||||
# First verify thread exists
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
# Then verify event exists and belongs to thread
|
||||
if not await thread_manager.event_belongs_to_thread(event_id, thread_id):
|
||||
raise HTTPException(status_code=404, detail="Event not found in this thread")
|
||||
|
||||
success = await thread_manager.delete_event(event_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete event")
|
||||
|
||||
# Broadcast to WebSocket connections
|
||||
background_tasks.add_task(
|
||||
ws_manager.broadcast_to_thread,
|
||||
thread_id,
|
||||
{"type": "event_deleted", "event_id": event_id}
|
||||
)
|
||||
return None
|
||||
|
||||
@app.patch("/threads/{thread_id}/events/{event_id}", status_code=200)
|
||||
async def update_event(thread_id: str, event_id: str, event: EventUpdate, background_tasks: BackgroundTasks):
|
||||
"""Update an existing event."""
|
||||
# First verify thread exists
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
# Then verify event exists and belongs to thread
|
||||
if not await thread_manager.event_belongs_to_thread(event_id, thread_id):
|
||||
raise HTTPException(status_code=404, detail="Event not found in this thread")
|
||||
|
||||
success = await thread_manager.update_event(
|
||||
event_id=event_id,
|
||||
thread_id=thread_id,
|
||||
content=event.content,
|
||||
include_in_llm_message_history=event.include_in_llm_message_history,
|
||||
llm_message=event.llm_message
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update event")
|
||||
|
||||
# Broadcast to WebSocket connections
|
||||
background_tasks.add_task(
|
||||
ws_manager.broadcast_to_thread,
|
||||
thread_id,
|
||||
{"type": "event_updated", "event_id": event_id, "updates": event.dict(exclude_unset=True)}
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
@app.get("/threads/{thread_id}/events")
|
||||
async def get_thread_events(
|
||||
thread_id: str,
|
||||
only_llm_messages: bool = False,
|
||||
event_types: Optional[List[str]] = None,
|
||||
order_by: str = "created_at",
|
||||
order: str = "ASC"
|
||||
):
|
||||
"""Get events from a thread with filtering options."""
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
events = await thread_manager.get_thread_events(
|
||||
thread_id=thread_id,
|
||||
only_llm_messages=only_llm_messages,
|
||||
event_types=event_types,
|
||||
order_by=order_by,
|
||||
order=order
|
||||
)
|
||||
return {"events": events}
|
||||
|
||||
@app.get("/threads/{thread_id}/messages")
|
||||
async def get_thread_messages(thread_id: str):
|
||||
"""Get LLM-formatted messages from thread events."""
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
messages = await thread_manager.get_thread_llm_messages(thread_id)
|
||||
return {"messages": messages}
|
||||
|
||||
# WebSocket Endpoint
|
||||
@app.websocket("/ws/threads/{thread_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, thread_id: str):
|
||||
"""WebSocket endpoint for real-time thread updates."""
|
||||
# Verify thread exists before accepting connection
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
await websocket.close(code=4004, reason="Thread not found")
|
||||
return
|
||||
|
||||
await ws_manager.connect(websocket, thread_id)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_json()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(websocket, thread_id)
|
||||
except Exception as e:
|
||||
await websocket.send_json({"type": "error", "detail": str(e)})
|
||||
ws_manager.disconnect(websocket, thread_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
@ -0,0 +1 @@
|
|||
# Empty file to mark as package
|
|
@ -0,0 +1,255 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from pydantic import BaseModel
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
import asyncio
|
||||
import uvicorn
|
||||
import logging
|
||||
from agentpress.api.ws import ws_manager
|
||||
from agentpress.api.api_factory import (
|
||||
app as thread_task_app,
|
||||
register_thread_task_api,
|
||||
discover_tasks,
|
||||
thread_manager as task_thread_manager
|
||||
)
|
||||
# from agentpress.api_factory import app as api_app, discover_tasks
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global managers
|
||||
thread_manager: Optional[ThreadManager] = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for FastAPI application."""
|
||||
# Startup
|
||||
global thread_manager
|
||||
thread_manager = ThreadManager()
|
||||
|
||||
# Share thread_manager with task API
|
||||
global task_thread_manager
|
||||
task_thread_manager = thread_manager
|
||||
|
||||
# Wait for DB initialization
|
||||
db = thread_manager.db
|
||||
if db._initialization_task:
|
||||
await db._initialization_task
|
||||
|
||||
# Run task discovery during startup
|
||||
discover_tasks()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
# Add any cleanup code here if needed
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(title="AgentPress API", lifespan=lifespan)
|
||||
|
||||
# Enable CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# # Import and mount the API Factory app
|
||||
# try:
|
||||
# # Run task discovery
|
||||
# # discover_tasks()
|
||||
# logger.info("Task discovery completed")
|
||||
|
||||
# # Mount the API Factory app at /tasks instead of root
|
||||
# app.mount("/tasks", api_app)
|
||||
# logger.info("Mounted API Factory app at /tasks")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error setting up API Factory: {e}")
|
||||
# raise
|
||||
|
||||
# Pydantic models for request/response validation
|
||||
class MessageCreate(BaseModel):
|
||||
"""Model for creating messages in a thread."""
|
||||
message_data: Union[str, Dict[str, Any]]
|
||||
images: Optional[List[Dict[str, Any]]] = None
|
||||
include_in_llm_message_history: bool = True
|
||||
message_type: Optional[str] = None
|
||||
|
||||
# REST API Endpoints
|
||||
@app.post("/threads", response_model=dict, status_code=201)
|
||||
async def create_thread():
|
||||
"""Create a new thread."""
|
||||
thread_id = await thread_manager.create_thread()
|
||||
return {"thread_id": thread_id}
|
||||
|
||||
@app.post("/threads/{thread_id}/messages", response_model=dict, status_code=201)
|
||||
async def create_message(thread_id: str, message: MessageCreate, background_tasks: BackgroundTasks):
|
||||
"""Create a new message in a thread."""
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
try:
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
message_data=message.message_data,
|
||||
images=message.images,
|
||||
include_in_llm_message_history=message.include_in_llm_message_history,
|
||||
message_type=message.message_type
|
||||
)
|
||||
|
||||
# Broadcast to WebSocket connections
|
||||
background_tasks.add_task(
|
||||
ws_manager.broadcast_to_thread,
|
||||
thread_id,
|
||||
{"type": "message_created", "message": message.dict()}
|
||||
)
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# TODO: BROKEN FOR SOME REASON – RETURNS [] SHOULD RETURN, LLM MESSAGE STYLE
|
||||
@app.get("/threads/{thread_id}/llm_history_messages")
|
||||
async def get_thread_llm_messages(
|
||||
thread_id: str,
|
||||
hide_tool_msgs: bool = False,
|
||||
only_latest_assistant: bool = False,
|
||||
):
|
||||
"""Get messages from a thread with filtering options."""
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
messages = await thread_manager.get_llm_history_messages(
|
||||
thread_id=thread_id,
|
||||
hide_tool_msgs=hide_tool_msgs,
|
||||
only_latest_assistant=only_latest_assistant,
|
||||
)
|
||||
return {"messages": messages}
|
||||
|
||||
@app.get("/threads/{thread_id}/messages")
|
||||
async def get_thread_messages(
|
||||
thread_id: str,
|
||||
message_types: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 50,
|
||||
offset: Optional[int] = 0,
|
||||
before_timestamp: Optional[str] = None,
|
||||
after_timestamp: Optional[str] = None,
|
||||
include_in_llm_message_history: Optional[bool] = None,
|
||||
order: str = "asc"
|
||||
):
|
||||
"""
|
||||
Get messages from a thread with comprehensive filtering options.
|
||||
|
||||
Args:
|
||||
thread_id: Thread identifier
|
||||
message_types: Optional list of message types to filter by
|
||||
limit: Maximum number of messages to return (default: 50)
|
||||
offset: Number of messages to skip for pagination
|
||||
before_timestamp: Optional filter for messages before timestamp
|
||||
after_timestamp: Optional filter for messages after timestamp
|
||||
include_in_llm_message_history: Optional filter for LLM history inclusion
|
||||
order: Sort order - "asc" or "desc"
|
||||
"""
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
try:
|
||||
messages = await thread_manager.get_messages(
|
||||
thread_id=thread_id,
|
||||
message_types=message_types,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
before_timestamp=before_timestamp,
|
||||
after_timestamp=after_timestamp,
|
||||
include_in_llm_message_history=include_in_llm_message_history,
|
||||
order=order
|
||||
)
|
||||
return {"messages": messages}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# TODO ONLY SEND POLLING UPDATES (IN EVEN HIGHER FREQUENCY THEN 1per sec) - IF THEY ARE ANY ACTIVE TASKS FOR THAT THREAD. AS LONG AS THEY ARE ACTIVE TASKS START & STOP THE POLLING BASED ON WHETHER THERE IS AN ACTIVE TASK FOR THE THREAD. IMPLEMENT in API_FACTORY as well to broadcast this ofc & trigger/disable the polling.
|
||||
|
||||
# WebSocket Endpoint
|
||||
@app.websocket("/threads/{thread_id}")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
thread_id: str,
|
||||
message_types: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 50,
|
||||
offset: Optional[int] = 0,
|
||||
before_timestamp: Optional[str] = None,
|
||||
after_timestamp: Optional[str] = None,
|
||||
include_in_llm_message_history: Optional[bool] = None,
|
||||
order: str = "desc"
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time thread updates with filtering and pagination.
|
||||
|
||||
Query Parameters:
|
||||
message_types: Optional list of message types to filter by
|
||||
limit: Maximum number of messages to return (default: 50)
|
||||
offset: Number of messages to skip (for pagination)
|
||||
before_timestamp: Optional timestamp to filter messages before
|
||||
after_timestamp: Optional timestamp to filter messages after
|
||||
include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion
|
||||
order: Sort order - "asc" or "desc" (default: desc)
|
||||
"""
|
||||
try:
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
await websocket.close(code=4004, reason="Thread not found")
|
||||
return
|
||||
|
||||
await ws_manager.connect(websocket, thread_id)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Get messages with all filters
|
||||
result = await thread_manager.get_messages(
|
||||
thread_id=thread_id,
|
||||
message_types=message_types,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
before_timestamp=before_timestamp,
|
||||
after_timestamp=after_timestamp,
|
||||
include_in_llm_message_history=include_in_llm_message_history,
|
||||
order=order
|
||||
)
|
||||
|
||||
# Send messages and pagination info
|
||||
await websocket.send_json({
|
||||
"type": "messages",
|
||||
"data": result
|
||||
})
|
||||
|
||||
# Poll every second
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(websocket, thread_id)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error(f"WebSocket error: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": str(e)
|
||||
})
|
||||
ws_manager.disconnect(websocket, thread_id)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"WebSocket connection error: {e}")
|
||||
try:
|
||||
await websocket.close(code=1011, reason=str(e))
|
||||
except:
|
||||
pass
|
||||
|
||||
# Update the mounting of thread_task_app
|
||||
app.mount("/tasks", thread_task_app)
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
@ -0,0 +1,349 @@
|
|||
"""
|
||||
Thread Task API Factory for registering and managing thread-associated long-running tasks.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import inspect
|
||||
import uuid
|
||||
import asyncio
|
||||
import logging
|
||||
import importlib
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Any, Optional, List, ForwardRef
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import create_model
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize managers at module level
|
||||
thread_manager: Optional[ThreadManager] = None
|
||||
_decorated_functions: Dict[str, Callable] = {}
|
||||
_running_tasks: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def find_project_root():
|
||||
"""Find the project root by looking for pyproject.toml"""
|
||||
current = os.path.abspath(os.path.dirname(__file__))
|
||||
while current != '/':
|
||||
if os.path.exists(os.path.join(current, 'pyproject.toml')):
|
||||
return current
|
||||
current = os.path.dirname(current)
|
||||
return None
|
||||
|
||||
def discover_tasks():
|
||||
"""
|
||||
Discover all decorated functions in the project.
|
||||
Scans from the project root (where pyproject.toml is located).
|
||||
"""
|
||||
logger.info("Starting task discovery")
|
||||
|
||||
# Find project root
|
||||
project_root = find_project_root()
|
||||
logger.info(f"Project root found at: {project_root}")
|
||||
|
||||
# Add project root to Python path if not already there
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# Walk through all Python files in the project
|
||||
for root, _, files in os.walk(project_root):
|
||||
for file in files:
|
||||
if file.endswith('.py'):
|
||||
module_path = os.path.join(root, file)
|
||||
module_name = os.path.relpath(module_path, project_root)
|
||||
module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.')
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to import module: {module_name}")
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Inspect all module members
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isfunction(obj):
|
||||
# Check if this function has been decorated with register_thread_task_api
|
||||
if hasattr(obj, '__closure__') and obj.__closure__:
|
||||
for cell in obj.__closure__:
|
||||
if cell.cell_contents in _decorated_functions.values():
|
||||
path = next(
|
||||
p for p, f in _decorated_functions.items()
|
||||
if f == cell.cell_contents
|
||||
)
|
||||
_decorated_functions[path] = obj
|
||||
logger.info(f"Registered function: {obj.__name__} at path: {path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing {module_name}: {e}")
|
||||
|
||||
logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for FastAPI application."""
|
||||
global thread_manager
|
||||
|
||||
# Initialize ThreadManager if not already initialized
|
||||
if thread_manager is None:
|
||||
thread_manager = ThreadManager()
|
||||
# Wait for DB initialization
|
||||
if thread_manager.db._initialization_task:
|
||||
await thread_manager.db._initialization_task
|
||||
|
||||
# Run task discovery during startup
|
||||
discover_tasks()
|
||||
|
||||
yield
|
||||
# Cleanup if needed
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Thread Task API",
|
||||
description="API for managing thread-associated long-running tasks",
|
||||
openapi_tags=[{
|
||||
"name": "Generated Thread Tasks",
|
||||
"description": "Dynamically generated endpoints for thread-associated tasks"
|
||||
}],
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Add middleware to ensure thread_manager is available
|
||||
@app.middleware("http")
|
||||
async def ensure_thread_manager(request: Request, call_next):
|
||||
"""Ensure thread_manager is initialized before handling requests."""
|
||||
global thread_manager
|
||||
if thread_manager is None:
|
||||
thread_manager = ThreadManager()
|
||||
if thread_manager.db._initialization_task:
|
||||
await thread_manager.db._initialization_task
|
||||
return await call_next(request)
|
||||
|
||||
def register_thread_task_api(path: str):
|
||||
"""
|
||||
Decorator to register a function as a thread-associated task API endpoint.
|
||||
The decorated function must have thread_id as its first parameter.
|
||||
"""
|
||||
def decorator(func: Callable):
|
||||
logger.info(f"Registering thread task API endpoint: {path} for function {func.__name__}")
|
||||
_decorated_functions[path] = func
|
||||
|
||||
# Validate that thread_id is the first parameter
|
||||
params = inspect.signature(func).parameters
|
||||
if 'thread_id' not in params:
|
||||
raise ValueError(f"Function {func.__name__} must have thread_id as a parameter")
|
||||
|
||||
# Create Pydantic model for function parameters
|
||||
model_fields = {}
|
||||
for name, param in params.items():
|
||||
if name == 'self': # Skip self parameter for methods
|
||||
continue
|
||||
|
||||
annotation = param.annotation
|
||||
if annotation == inspect.Parameter.empty:
|
||||
annotation = Any
|
||||
|
||||
# Convert string annotations to ForwardRef
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
|
||||
default = ... if param.default == inspect.Parameter.empty else param.default
|
||||
model_fields[name] = (annotation, default)
|
||||
|
||||
RequestModel = create_model(f'{func.__name__}Request', **model_fields)
|
||||
|
||||
# Register the start endpoint
|
||||
@app.post(
|
||||
f"{path}/start",
|
||||
response_model=dict,
|
||||
summary=f"Start {func.__name__}",
|
||||
description=f"Start a new {func.__name__} task associated with a thread",
|
||||
tags=["Generated Thread Tasks"]
|
||||
)
|
||||
async def start_task(params: RequestModel):
|
||||
logger.info(f"Starting task at {path}/start")
|
||||
|
||||
# Validate thread exists
|
||||
if not await thread_manager.thread_exists(params.thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
kwargs = params.dict()
|
||||
|
||||
# Create the task
|
||||
task = asyncio.create_task(func(**kwargs))
|
||||
|
||||
# Store task with thread association
|
||||
_running_tasks[task_id] = {
|
||||
"thread_id": params.thread_id,
|
||||
"task": task,
|
||||
"status": "running",
|
||||
"path": path,
|
||||
"started_at": asyncio.get_event_loop().time()
|
||||
}
|
||||
|
||||
# Add task info to thread messages
|
||||
await thread_manager.add_message(
|
||||
thread_id=params.thread_id,
|
||||
message_data={
|
||||
"type": "task_started",
|
||||
"task_id": task_id,
|
||||
"path": path,
|
||||
"status": "running"
|
||||
},
|
||||
message_type="task_status",
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
return {"task_id": task_id}
|
||||
|
||||
# Register the stop endpoint
|
||||
@app.post(
|
||||
f"{path}/stop/{{task_id}}",
|
||||
response_model=dict,
|
||||
summary=f"Stop {func.__name__}",
|
||||
description=f"Stop a running {func.__name__} task",
|
||||
tags=["Generated Thread Tasks"]
|
||||
)
|
||||
async def stop_task(task_id: str):
|
||||
if task_id not in _running_tasks:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
task_info = _running_tasks[task_id]
|
||||
task_info["task"].cancel()
|
||||
task_info["status"] = "cancelled"
|
||||
|
||||
# Update thread with task cancellation
|
||||
await thread_manager.add_message(
|
||||
thread_id=task_info["thread_id"],
|
||||
message_data={
|
||||
"type": "task_stopped",
|
||||
"task_id": task_id,
|
||||
"path": task_info["path"],
|
||||
"status": "cancelled"
|
||||
},
|
||||
message_type="task_status",
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
return {"status": "stopped"}
|
||||
|
||||
# Register the status endpoint
|
||||
@app.get(
|
||||
f"{path}/status/{{task_id}}",
|
||||
response_model=dict,
|
||||
summary=f"Get {func.__name__} status",
|
||||
description=f"Get the status of a {func.__name__} task",
|
||||
tags=["Generated Thread Tasks"]
|
||||
)
|
||||
async def get_status(task_id: str):
|
||||
if task_id not in _running_tasks:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
task_info = _running_tasks[task_id]
|
||||
task = task_info["task"]
|
||||
|
||||
if task.done():
|
||||
try:
|
||||
result = task.result()
|
||||
status = "completed"
|
||||
if hasattr(result, '__aiter__'):
|
||||
status = "streaming"
|
||||
|
||||
# Update thread with task completion
|
||||
await thread_manager.add_message(
|
||||
thread_id=task_info["thread_id"],
|
||||
message_data={
|
||||
"type": "task_completed",
|
||||
"task_id": task_id,
|
||||
"path": task_info["path"],
|
||||
"status": status,
|
||||
"result": result if status == "completed" else None
|
||||
},
|
||||
message_type="task_status",
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"result": result if status == "completed" else None
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return {"status": "cancelled"}
|
||||
except Exception as e:
|
||||
error_status = {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Update thread with task failure
|
||||
await thread_manager.add_message(
|
||||
thread_id=task_info["thread_id"],
|
||||
message_data={
|
||||
"type": "task_failed",
|
||||
"task_id": task_id,
|
||||
"path": task_info["path"],
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
},
|
||||
message_type="task_status",
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
return error_status
|
||||
|
||||
return {"status": "running"}
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
@app.get("/threads/{thread_id}/tasks")
|
||||
async def get_thread_tasks(thread_id: str):
|
||||
"""Get all tasks associated with a thread."""
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
thread_tasks = {
|
||||
task_id: {
|
||||
"path": info["path"],
|
||||
"status": info["status"],
|
||||
"started_at": info["started_at"]
|
||||
}
|
||||
for task_id, info in _running_tasks.items()
|
||||
if info["thread_id"] == thread_id
|
||||
}
|
||||
|
||||
return {"tasks": thread_tasks}
|
||||
|
||||
@app.delete("/threads/{thread_id}/tasks")
|
||||
async def stop_thread_tasks(thread_id: str):
|
||||
"""Stop all tasks associated with a thread."""
|
||||
if not await thread_manager.thread_exists(thread_id):
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
stopped_tasks = []
|
||||
for task_id, info in list(_running_tasks.items()):
|
||||
if info["thread_id"] == thread_id:
|
||||
info["task"].cancel()
|
||||
info["status"] = "cancelled"
|
||||
stopped_tasks.append(task_id)
|
||||
|
||||
# Update thread with task cancellations
|
||||
if stopped_tasks:
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
message_data={
|
||||
"type": "tasks_stopped",
|
||||
"task_ids": stopped_tasks,
|
||||
"status": "cancelled"
|
||||
},
|
||||
message_type="task_status",
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
return {"stopped_tasks": stopped_tasks}
|
|
@ -0,0 +1,45 @@
|
|||
"""WebSocket management system for real-time updates."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
class WebSocketManager:
|
||||
"""Manages WebSocket connections for real-time thread updates."""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, List[WebSocket]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, thread_id: str):
|
||||
"""Connect a WebSocket to a thread."""
|
||||
await websocket.accept()
|
||||
if thread_id not in self.active_connections:
|
||||
self.active_connections[thread_id] = []
|
||||
self.active_connections[thread_id].append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket, thread_id: str):
|
||||
"""Disconnect a WebSocket from a thread."""
|
||||
if thread_id in self.active_connections:
|
||||
self.active_connections[thread_id].remove(websocket)
|
||||
if not self.active_connections[thread_id]:
|
||||
del self.active_connections[thread_id]
|
||||
|
||||
async def broadcast_to_thread(self, thread_id: str, message: dict):
|
||||
"""Broadcast a message to all connections in a thread."""
|
||||
if thread_id in self.active_connections:
|
||||
disconnected = []
|
||||
for connection in self.active_connections[thread_id]:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except WebSocketDisconnect:
|
||||
disconnected.append(connection)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to send message to websocket: {e}")
|
||||
disconnected.append(connection)
|
||||
|
||||
# Clean up disconnected connections
|
||||
for connection in disconnected:
|
||||
self.disconnect(connection, thread_id)
|
||||
|
||||
# Global WebSocket manager instance
|
||||
ws_manager = WebSocketManager()
|
|
@ -1,170 +0,0 @@
|
|||
"""
|
||||
API Factory for registering and managing FastAPI endpoints.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import inspect
|
||||
import importlib
|
||||
import pkgutil
|
||||
import uuid
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Any, Optional, List
|
||||
from fastapi import FastAPI, BackgroundTasks, HTTPException
|
||||
from pydantic import create_model, BaseModel
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
_decorated_functions: Dict[str, Callable] = {}
|
||||
_running_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
def register_api_endpoint(path: str):
|
||||
"""Decorator to register a function as an API endpoint with task management."""
|
||||
def decorator(func: Callable):
|
||||
logger.info(f"Registering API endpoint: {path} for function {func.__name__}")
|
||||
_decorated_functions[path] = func
|
||||
|
||||
# Create Pydantic model for function parameters
|
||||
params = inspect.signature(func).parameters
|
||||
model_fields = {
|
||||
name: (param.annotation if param.annotation != inspect.Parameter.empty else Any, ... if param.default == inspect.Parameter.empty else param.default)
|
||||
for name, param in params.items()
|
||||
if name != 'self' # Skip self parameter for methods
|
||||
}
|
||||
RequestModel = create_model(f'{func.__name__}Request', **model_fields)
|
||||
|
||||
# Register the start endpoint
|
||||
@app.post(f"{path}/start", response_model=dict)
|
||||
async def start_task(params: Optional[RequestModel] = None):
|
||||
logger.info(f"Starting task at {path}/start")
|
||||
task_id = str(uuid.uuid4())
|
||||
kwargs = params.dict() if params else {}
|
||||
_running_tasks[task_id] = asyncio.create_task(func(**kwargs))
|
||||
return {"task_id": task_id}
|
||||
|
||||
# Register the stop endpoint
|
||||
@app.post(f"{path}/stop/{{task_id}}")
|
||||
async def stop_task(task_id: str):
|
||||
if task_id not in _running_tasks:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
_running_tasks[task_id].cancel()
|
||||
return {"status": "stopped"}
|
||||
|
||||
# Register the status endpoint
|
||||
@app.get(f"{path}/status/{{task_id}}")
|
||||
async def get_status(task_id: str):
|
||||
if task_id not in _running_tasks:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
task = _running_tasks[task_id]
|
||||
|
||||
if task.done():
|
||||
try:
|
||||
result = task.result()
|
||||
# Check if this is a streaming response
|
||||
if hasattr(result, '__aiter__') or (
|
||||
isinstance(result, dict) and (
|
||||
any(hasattr(v, '__aiter__') for v in result.values()) or
|
||||
# Check for streaming responses in iterations
|
||||
(
|
||||
'iterations' in result and
|
||||
result['iterations'] and
|
||||
any(hasattr(r.get('response'), '__aiter__') for r in result['iterations'])
|
||||
)
|
||||
)
|
||||
):
|
||||
return {"status": "streaming"}
|
||||
return {"status": "completed", "result": result}
|
||||
except asyncio.CancelledError:
|
||||
return {"status": "cancelled"}
|
||||
except Exception as e:
|
||||
return {"status": "failed", "error": str(e)}
|
||||
return {"status": "running"}
|
||||
|
||||
# Also register a direct endpoint for simple calls
|
||||
@app.post(path)
|
||||
async def direct_call(background_tasks: BackgroundTasks, params: Optional[RequestModel] = None):
|
||||
kwargs = params.dict() if params else {}
|
||||
task_id = str(uuid.uuid4())
|
||||
task = asyncio.create_task(func(**kwargs))
|
||||
_running_tasks[task_id] = task
|
||||
return {"task_id": task_id}
|
||||
|
||||
logger.info(f"Successfully registered all endpoints for {path}")
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
def find_project_root() -> str:
|
||||
"""Find the project root by looking for pyproject.toml."""
|
||||
current = os.path.abspath(os.path.dirname(__file__))
|
||||
while current != '/':
|
||||
if os.path.exists(os.path.join(current, 'pyproject.toml')):
|
||||
return current
|
||||
current = os.path.dirname(current)
|
||||
return os.path.dirname(__file__) # Fallback to current directory
|
||||
|
||||
def discover_tasks():
|
||||
"""
|
||||
Discover all decorated functions in the project.
|
||||
Scans from the project root (where pyproject.toml is located).
|
||||
"""
|
||||
logger.info("Starting task discovery")
|
||||
|
||||
# Find project root
|
||||
project_root = find_project_root()
|
||||
logger.info(f"Project root found at: {project_root}")
|
||||
|
||||
# Add project root to Python path if not already there
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# Walk through all Python files in the project
|
||||
for root, _, files in os.walk(project_root):
|
||||
for file in files:
|
||||
if file.endswith('.py'):
|
||||
module_path = os.path.join(root, file)
|
||||
module_name = os.path.relpath(module_path, project_root)
|
||||
module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.')
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to import module: {module_name}")
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Inspect all module members
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isfunction(obj):
|
||||
# Check if this function has been decorated with @task
|
||||
if any(
|
||||
path for path, func in _decorated_functions.items()
|
||||
if func.__name__ == obj.__name__ and func.__module__ == obj.__module__
|
||||
):
|
||||
logger.info(f"Found already registered function: {obj.__name__}")
|
||||
continue
|
||||
|
||||
# Check for our decorator in the function's closure
|
||||
if hasattr(obj, '__closure__') and obj.__closure__:
|
||||
for cell in obj.__closure__:
|
||||
if cell.cell_contents in _decorated_functions.values():
|
||||
# Found a decorated function that wasn't registered
|
||||
path = next(
|
||||
p for p, f in _decorated_functions.items()
|
||||
if f == cell.cell_contents
|
||||
)
|
||||
_decorated_functions[path] = obj
|
||||
logger.info(f"Registered function: {obj.__name__} at path: {path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing {module_name}: {e}")
|
||||
|
||||
logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}")
|
||||
|
||||
# Auto-discover tasks on import
|
||||
discover_tasks()
|
|
@ -2,99 +2,44 @@ import os
|
|||
import shutil
|
||||
import click
|
||||
import questionary
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from typing import Dict
|
||||
import time
|
||||
import pkg_resources
|
||||
import requests
|
||||
from packaging import version
|
||||
import re
|
||||
|
||||
MODULES = {
|
||||
"llm": {
|
||||
"required": True,
|
||||
"files": ["llm.py"],
|
||||
"description": "LLM Interface - Core module for interacting with large language models (OpenAI, Anthropic, 100+ LLMs using the OpenAI Input/Output Format powered by LiteLLM). Handles API calls, response streaming, and model-specific configurations."
|
||||
},
|
||||
"tool": {
|
||||
"required": True,
|
||||
"files": [
|
||||
"tool.py",
|
||||
"tool_registry.py"
|
||||
],
|
||||
"description": "Tool System Foundation - Defines the base architecture for creating and managing tools. Includes the tool registry for registering, organizing, and accessing tool functions."
|
||||
},
|
||||
"processors": {
|
||||
"required": True,
|
||||
"files": [
|
||||
"processor/base_processors.py",
|
||||
"processor/llm_response_processor.py",
|
||||
"processor/standard/standard_tool_parser.py",
|
||||
"processor/standard/standard_tool_executor.py",
|
||||
"processor/standard/standard_results_adder.py",
|
||||
"processor/xml/xml_tool_parser.py",
|
||||
"processor/xml/xml_tool_executor.py",
|
||||
"processor/xml/xml_results_adder.py"
|
||||
],
|
||||
"description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns."
|
||||
},
|
||||
"thread_management": {
|
||||
"required": True,
|
||||
"files": [
|
||||
"thread_manager.py",
|
||||
"thread_viewer_ui.py"
|
||||
],
|
||||
"description": "Conversation Management System - Handles message threading, conversation history, and provides a UI for viewing conversation threads. Manages the flow of messages between the user, LLM, and tools."
|
||||
},
|
||||
"state_management": {
|
||||
"required": True,
|
||||
"files": ["state_manager.py"],
|
||||
"description": "State Persistence System - Provides thread-safe storage and retrieval of conversation state, tool data, and other persistent information. Enables maintaining context across sessions and managing shared state between components."
|
||||
},
|
||||
"db_connection": {
|
||||
"required": True,
|
||||
"files": ["db_connection.py"],
|
||||
"description": "Database Connection - Provides a connection to a SQLite database for storing and retrieving conversation state, tool data, and other persistent information."
|
||||
}
|
||||
}
|
||||
PACKAGE_NAME = "agentpress"
|
||||
PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json"
|
||||
|
||||
STARTER_EXAMPLES = {
|
||||
"simple_web_dev_example_agent": {
|
||||
"description": "Interactive web development agent with file and terminal manipulation capabilities. Demonstrates both standard and XML-based tool calling patterns.",
|
||||
"files": {
|
||||
"agent.py": "examples/simple_web_dev/agent.py",
|
||||
"tools/files_tool.py": "examples/simple_web_dev/tools/files_tool.py",
|
||||
"tools/terminal_tool.py": "examples/simple_web_dev/tools/terminal_tool.py",
|
||||
".env.example": "examples/.env.example"
|
||||
"agent.py": "agents/simple_web_dev/agent.py",
|
||||
"tools/files_tool.py": "agents/simple_web_dev/tools/files_tool.py",
|
||||
"tools/terminal_tool.py": "agents/simple_web_dev/tools/terminal_tool.py",
|
||||
".env.example": "agents/.env.example"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PACKAGE_NAME = "agentpress"
|
||||
PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json"
|
||||
|
||||
def check_for_updates() -> Tuple[Optional[str], Optional[str], bool]:
|
||||
"""
|
||||
Check if there's a newer version available on PyPI
|
||||
Returns: (current_version, latest_version, update_available)
|
||||
"""
|
||||
def check_for_updates():
|
||||
"""Check if there's a newer version available on PyPI"""
|
||||
try:
|
||||
current_version = pkg_resources.get_distribution(PACKAGE_NAME).version
|
||||
response = requests.get(PYPI_URL, timeout=2)
|
||||
response.raise_for_status() # Raise exception for bad status codes
|
||||
response.raise_for_status()
|
||||
|
||||
latest_version = response.json()["info"]["version"]
|
||||
|
||||
# Compare versions properly using packaging.version
|
||||
current_ver = version.parse(current_version)
|
||||
latest_ver = version.parse(latest_version)
|
||||
|
||||
return current_version, latest_version, latest_ver > current_ver
|
||||
|
||||
except requests.RequestException:
|
||||
# Handle network-related errors silently
|
||||
return None, None, False
|
||||
except Exception as e:
|
||||
# Log other unexpected errors but don't break the CLI
|
||||
click.echo(f"Warning: Failed to check for updates: {str(e)}", err=True)
|
||||
return None, None, False
|
||||
|
||||
|
@ -102,7 +47,6 @@ def show_welcome():
|
|||
"""Display welcome message with ASCII art"""
|
||||
click.clear()
|
||||
|
||||
# Check for updates
|
||||
current_version, latest_version, update_available = check_for_updates()
|
||||
|
||||
click.echo("""
|
||||
|
@ -122,16 +66,17 @@ def show_welcome():
|
|||
|
||||
time.sleep(1)
|
||||
|
||||
def copy_module_files(src_dir: str, dest_dir: str, files: List[str]):
|
||||
"""Copy module files from package to destination"""
|
||||
def copy_package_files(src_dir: str, dest_dir: str):
|
||||
"""Copy all package files except agents folder to destination"""
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
with click.progressbar(files, label='Copying files') as file_list:
|
||||
for file in file_list:
|
||||
src = os.path.join(src_dir, file)
|
||||
dst = os.path.join(dest_dir, file)
|
||||
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
||||
shutil.copy2(src, dst)
|
||||
def ignore_patterns(path, names):
|
||||
# Ignore the agents directory and any __pycache__ directories
|
||||
return [n for n in names if n == 'agents' or n == '__pycache__']
|
||||
|
||||
with click.progressbar(length=1, label='Copying files') as bar:
|
||||
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True, ignore=ignore_patterns)
|
||||
bar.update(1)
|
||||
|
||||
def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]):
|
||||
"""Copy example files from package to destination"""
|
||||
|
@ -142,19 +87,6 @@ def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]):
|
|||
shutil.copy2(src, dst)
|
||||
click.echo(f" ✓ Created {dest_path}")
|
||||
|
||||
def update_file_paths(file_path: str, replacements: Dict[str, str]):
|
||||
"""Update file paths in the given file"""
|
||||
with open(file_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
for old, new in replacements.items():
|
||||
# Escape special characters in the old string
|
||||
escaped_old = re.escape(old)
|
||||
content = re.sub(escaped_old, new, content)
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""AgentPress CLI - Initialize your AgentPress modules"""
|
||||
|
@ -165,7 +97,6 @@ def init():
|
|||
"""Initialize AgentPress modules in your project"""
|
||||
show_welcome()
|
||||
|
||||
# Set components directory name to 'agentpress'
|
||||
components_dir = "agentpress"
|
||||
|
||||
if os.path.exists(components_dir):
|
||||
|
@ -195,41 +126,20 @@ def init():
|
|||
# Get package directory
|
||||
package_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Show all modules status
|
||||
click.echo("\n🔧 AgentPress Modules Configuration\n")
|
||||
|
||||
# Show required modules including state_manager
|
||||
click.echo("📦 Required Modules (pre-selected):")
|
||||
required_modules = {name: module for name, module in MODULES.items()
|
||||
if module["required"] or name == "state_management"}
|
||||
for name, module in required_modules.items():
|
||||
click.echo(f" ✓ {click.style(name, fg='green')} - {module['description']}")
|
||||
|
||||
# Create selections dict with required modules pre-selected
|
||||
selections = {name: True for name in required_modules.keys()}
|
||||
|
||||
click.echo("\n🚀 Setting up your AgentPress...")
|
||||
time.sleep(0.5)
|
||||
|
||||
try:
|
||||
# Copy selected modules
|
||||
selected_modules = [name for name, selected in selections.items() if selected]
|
||||
all_files = []
|
||||
for module in selected_modules:
|
||||
all_files.extend(MODULES[module]["files"])
|
||||
|
||||
# Create components directory and copy module files
|
||||
# Create components directory and copy all files except agents folder
|
||||
components_dir_path = os.path.abspath(components_dir)
|
||||
copy_module_files(package_dir, components_dir_path, all_files)
|
||||
|
||||
copy_package_files(package_dir, components_dir_path)
|
||||
|
||||
|
||||
# Copy example only if a valid example (not None) was selected
|
||||
# Copy example if selected
|
||||
if selected_example and selected_example in STARTER_EXAMPLES:
|
||||
click.echo(f"\n📝 Creating {selected_example}...")
|
||||
copy_example_files(
|
||||
package_dir,
|
||||
os.getcwd(), # Use current working directory
|
||||
os.getcwd(),
|
||||
STARTER_EXAMPLES[selected_example]["files"]
|
||||
)
|
||||
|
||||
|
@ -246,7 +156,6 @@ def init():
|
|||
click.echo(f"\nRun the example agent:")
|
||||
click.echo(" python agent.py")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"\n❌ Error during setup: {str(e)}", err=True)
|
||||
return
|
||||
|
|
|
@ -7,105 +7,122 @@ import logging
|
|||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
class DBConnection:
|
||||
"""Singleton database connection manager."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
_db_path = "ap.db"
|
||||
_db_path = "/Users/markokraemer/Projects/softgen/softgen-core/main.db"
|
||||
_init_lock = asyncio.Lock()
|
||||
_initialization_task = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
# Start initialization when instance is first created
|
||||
cls._initialization_task = asyncio.create_task(cls._instance._initialize())
|
||||
return cls._instance
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the database connection and schema."""
|
||||
if self._initialized:
|
||||
def __init__(self):
|
||||
"""No initialization needed in __init__ as it's handled in __new__"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _initialize(cls):
|
||||
"""Internal initialization method."""
|
||||
if cls._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# Ensure the database directory exists
|
||||
os.makedirs(os.path.dirname(os.path.abspath(self._db_path)), exist_ok=True)
|
||||
|
||||
# Initialize database and create schema
|
||||
async with aiosqlite.connect(self._db_path) as db:
|
||||
await db.execute("PRAGMA foreign_keys = ON")
|
||||
async with cls._init_lock:
|
||||
if cls._initialized: # Double-check after acquiring lock
|
||||
return
|
||||
|
||||
# Create threads table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
try:
|
||||
async with aiosqlite.connect(cls._db_path) as db:
|
||||
# Threads table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Messages table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
thread_id TEXT,
|
||||
type TEXT,
|
||||
content TEXT,
|
||||
include_in_llm_message_history BOOLEAN DEFAULT TRUE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (thread_id) REFERENCES threads (id)
|
||||
)
|
||||
""")
|
||||
|
||||
await db.commit()
|
||||
cls._initialized = True
|
||||
logging.info("Database schema initialized")
|
||||
except Exception as e:
|
||||
logging.error(f"Database initialization error: {e}")
|
||||
raise
|
||||
|
||||
# Create events table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id TEXT PRIMARY KEY,
|
||||
thread_id TEXT,
|
||||
type TEXT,
|
||||
content TEXT,
|
||||
include_in_llm_message_history INTEGER DEFAULT 0,
|
||||
llm_message TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
@classmethod
|
||||
def set_db_path(cls, db_path: str):
|
||||
"""Set custom database path."""
|
||||
if cls._initialized:
|
||||
raise RuntimeError("Cannot change database path after initialization")
|
||||
cls._db_path = db_path
|
||||
logging.info(f"Updated database path to: {db_path}")
|
||||
|
||||
await db.commit()
|
||||
logging.info("Database initialized successfully")
|
||||
self._initialized = True
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database: {e}")
|
||||
raise
|
||||
@asynccontextmanager
|
||||
async def connection(self):
|
||||
"""Get a database connection."""
|
||||
# Wait for initialization to complete if it hasn't already
|
||||
if self._initialization_task and not self._initialized:
|
||||
await self._initialization_task
|
||||
|
||||
async with aiosqlite.connect(self._db_path) as conn:
|
||||
try:
|
||||
yield conn
|
||||
except Exception as e:
|
||||
logging.error(f"Database error: {e}")
|
||||
raise
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""Get a database connection with transaction support."""
|
||||
if not self._initialized:
|
||||
raise Exception("Database not initialized. Call initialize() first.")
|
||||
|
||||
async with aiosqlite.connect(self._db_path) as db:
|
||||
await db.execute("PRAGMA foreign_keys = ON")
|
||||
"""Execute operations in a transaction."""
|
||||
async with self.connection() as db:
|
||||
try:
|
||||
yield db
|
||||
await db.commit()
|
||||
logging.debug("Transaction committed successfully")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logging.error(f"Transaction failed, rolling back: {e}")
|
||||
logging.error(f"Transaction error: {e}")
|
||||
raise
|
||||
|
||||
async def execute(self, query: str, params: tuple = ()):
|
||||
"""Execute a query and return the cursor."""
|
||||
async with aiosqlite.connect(self._db_path) as db:
|
||||
await db.execute("PRAGMA foreign_keys = ON")
|
||||
return await db.execute(query, params)
|
||||
|
||||
async def fetch_all(self, query: str, params: tuple = ()):
|
||||
"""Execute a query and fetch all results."""
|
||||
async with aiosqlite.connect(self._db_path) as db:
|
||||
await db.execute("PRAGMA foreign_keys = ON")
|
||||
cursor = await db.execute(query, params)
|
||||
return await cursor.fetchall()
|
||||
"""Execute a single query."""
|
||||
async with self.connection() as db:
|
||||
try:
|
||||
result = await db.execute(query, params)
|
||||
await db.commit()
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Query execution error: {e}")
|
||||
raise
|
||||
|
||||
async def fetch_one(self, query: str, params: tuple = ()):
|
||||
"""Execute a query and fetch one result."""
|
||||
async with aiosqlite.connect(self._db_path) as db:
|
||||
await db.execute("PRAGMA foreign_keys = ON")
|
||||
cursor = await db.execute(query, params)
|
||||
return await cursor.fetchone()
|
||||
"""Fetch a single row."""
|
||||
async with self.connection() as db:
|
||||
async with db.execute(query, params) as cursor:
|
||||
return await cursor.fetchone()
|
||||
|
||||
def _serialize_json(self, data):
|
||||
"""Serialize data to JSON string."""
|
||||
return json.dumps(data) if data is not None else None
|
||||
|
||||
def _deserialize_json(self, data):
|
||||
"""Deserialize JSON string to data."""
|
||||
return json.loads(data) if data is not None else None
|
||||
async def fetch_all(self, query: str, params: tuple = ()):
|
||||
"""Fetch all rows."""
|
||||
async with self.connection() as db:
|
||||
async with db.execute(query, params) as cursor:
|
||||
return await cursor.fetchall()
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Union, Dict, Any
|
||||
from typing import Union, Dict, Any, Optional, List
|
||||
import litellm
|
||||
import os
|
||||
import json
|
||||
|
@ -11,6 +11,14 @@ OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', None)
|
|||
ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY', None)
|
||||
GROQ_API_KEY = os.environ.get('GROQ_API_KEY', None)
|
||||
AGENTOPS_API_KEY = os.environ.get('AGENTOPS_API_KEY', None)
|
||||
FIREWORKS_API_KEY = os.environ.get('FIREWORKS_AI_API_KEY', None)
|
||||
DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY', None)
|
||||
OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', None)
|
||||
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', None)
|
||||
|
||||
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', None)
|
||||
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', None)
|
||||
AWS_REGION_NAME = os.environ.get('AWS_REGION_NAME', None)
|
||||
|
||||
if OPENAI_API_KEY:
|
||||
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
|
||||
|
@ -18,6 +26,22 @@ if ANTHROPIC_API_KEY:
|
|||
os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY
|
||||
if GROQ_API_KEY:
|
||||
os.environ['GROQ_API_KEY'] = GROQ_API_KEY
|
||||
if FIREWORKS_API_KEY:
|
||||
os.environ['FIREWORKS_AI_API_KEY'] = FIREWORKS_API_KEY
|
||||
if DEEPSEEK_API_KEY:
|
||||
os.environ['DEEPSEEK_API_KEY'] = DEEPSEEK_API_KEY
|
||||
if OPENROUTER_API_KEY:
|
||||
os.environ['OPENROUTER_API_KEY'] = OPENROUTER_API_KEY
|
||||
if GEMINI_API_KEY:
|
||||
os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY
|
||||
|
||||
# Add AWS environment variables if they exist
|
||||
if AWS_ACCESS_KEY_ID:
|
||||
os.environ['AWS_ACCESS_KEY_ID'] = AWS_ACCESS_KEY_ID
|
||||
if AWS_SECRET_ACCESS_KEY:
|
||||
os.environ['AWS_SECRET_ACCESS_KEY'] = AWS_SECRET_ACCESS_KEY
|
||||
if AWS_REGION_NAME:
|
||||
os.environ['AWS_REGION_NAME'] = AWS_REGION_NAME
|
||||
|
||||
async def make_llm_api_call(
|
||||
messages: list,
|
||||
|
@ -31,7 +55,8 @@ async def make_llm_api_call(
|
|||
api_base: str = None,
|
||||
agentops_session: Any = None,
|
||||
stream: bool = False,
|
||||
top_p: float = None
|
||||
top_p: float = None,
|
||||
stop: Optional[Union[str, List[str]]] = None # Add stop parameter
|
||||
) -> Union[Dict[str, Any], Any]:
|
||||
"""
|
||||
Make an API call to a language model using litellm.
|
||||
|
@ -52,6 +77,7 @@ async def make_llm_api_call(
|
|||
agentops_session (Any, optional): Session for agentops integration
|
||||
stream (bool, optional): Whether to stream the response. Defaults to False
|
||||
top_p (float, optional): Top-p sampling parameter
|
||||
stop (Union[str, List[str]], optional): Up to 4 sequences where the API will stop generating tokens
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, Any], Any]: API response, either complete or streaming
|
||||
|
@ -59,7 +85,7 @@ async def make_llm_api_call(
|
|||
Raises:
|
||||
Exception: If API call fails after retries
|
||||
"""
|
||||
# litellm.set_verbose = True
|
||||
litellm.set_verbose = False
|
||||
|
||||
async def attempt_api_call(api_call_func, max_attempts=3):
|
||||
"""
|
||||
|
@ -75,10 +101,17 @@ async def make_llm_api_call(
|
|||
Raises:
|
||||
Exception: If all retry attempts fail
|
||||
"""
|
||||
nonlocal model_name # Add this to access model_name
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return await api_call_func()
|
||||
except litellm.exceptions.RateLimitError as e:
|
||||
# Check if it's Bedrock Claude and switch to direct Anthropic
|
||||
if "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0" in model_name:
|
||||
logging.info("Rate limit hit with Bedrock Claude, falling back to direct Anthropic API...")
|
||||
model_name = "anthropic/claude-3-5-sonnet-latest"
|
||||
continue
|
||||
|
||||
logging.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...")
|
||||
await asyncio.sleep(30)
|
||||
except OpenAIError as e:
|
||||
|
@ -105,6 +138,10 @@ async def make_llm_api_call(
|
|||
"stream": stream,
|
||||
}
|
||||
|
||||
# Add stop parameter if provided
|
||||
if stop is not None:
|
||||
api_call_params["stop"] = stop
|
||||
|
||||
# Add optional parameters if provided
|
||||
if api_key:
|
||||
api_call_params["api_key"] = api_key
|
||||
|
@ -129,6 +166,32 @@ async def make_llm_api_call(
|
|||
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
|
||||
}
|
||||
|
||||
# Add OpenRouter specific parameters
|
||||
if "openrouter" in model_name.lower():
|
||||
if settings.or_site_url:
|
||||
api_call_params["headers"] = {
|
||||
"HTTP-Referer": settings.or_site_url
|
||||
}
|
||||
if settings.or_app_name:
|
||||
api_call_params["headers"] = {
|
||||
"X-Title": settings.or_app_name
|
||||
}
|
||||
|
||||
# Add special handling for Deepseek
|
||||
if "deepseek" in model_name.lower():
|
||||
api_call_params["frequency_penalty"] = 0.5
|
||||
api_call_params["temperature"] = 0.7
|
||||
api_call_params["presence_penalty"] = 0.1
|
||||
|
||||
# Add Bedrock-specific parameters
|
||||
if "bedrock" in model_name.lower():
|
||||
if settings.aws_access_key_id:
|
||||
api_call_params["aws_access_key_id"] = settings.aws_access_key_id
|
||||
if settings.aws_secret_access_key:
|
||||
api_call_params["aws_secret_access_key"] = settings.aws_secret_access_key
|
||||
if settings.aws_region_name:
|
||||
api_call_params["aws_region_name"] = settings.aws_region_name
|
||||
|
||||
# Log the API request
|
||||
# logging.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")
|
||||
|
||||
|
@ -137,10 +200,36 @@ async def make_llm_api_call(
|
|||
response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
|
||||
else:
|
||||
response = await litellm.acompletion(**api_call_params)
|
||||
|
||||
# Log the API response
|
||||
|
||||
# logging.info(f"Received API response: {response}")
|
||||
|
||||
# # For streaming responses, attach cost tracking
|
||||
# if stream:
|
||||
# # Create a wrapper object to track costs across chunks
|
||||
# cost_tracker = {
|
||||
# "prompt_tokens": 0,
|
||||
# "completion_tokens": 0,
|
||||
# "total_tokens": 0,
|
||||
# "cost": 0.0
|
||||
# }
|
||||
|
||||
# # Get the cost per token for the model
|
||||
# model_cost = litellm.model_cost.get(model_name, {})
|
||||
# input_cost = model_cost.get('input_cost_per_token', 0)
|
||||
# output_cost = model_cost.get('output_cost_per_token', 0)
|
||||
|
||||
# # Attach the cost tracker to the response
|
||||
# response.cost_tracker = cost_tracker
|
||||
# response.model_info = {
|
||||
# "input_cost_per_token": input_cost,
|
||||
# "output_cost_per_token": output_cost
|
||||
# }
|
||||
# else:
|
||||
# # For non-streaming, cost is already included in the response
|
||||
# response._hidden_params = {
|
||||
# "response_cost": litellm.completion_cost(completion_response=response)
|
||||
# }
|
||||
|
||||
return response
|
||||
|
||||
return await attempt_api_call(api_call)
|
||||
|
@ -188,4 +277,37 @@ if __name__ == "__main__":
|
|||
print(response.choices[0].message.content)
|
||||
print()
|
||||
|
||||
asyncio.run(test_llm_api_call())
|
||||
# asyncio.run(test_llm_api_call())
|
||||
|
||||
async def test_bedrock():
|
||||
"""
|
||||
Test function for Bedrock API call.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello from Bedrock!"}
|
||||
]
|
||||
model_name = "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
response = await make_llm_api_call(messages, model_name, stream=True)
|
||||
|
||||
print("\n🤖 Streaming response from Bedrock:\n")
|
||||
buffer = ""
|
||||
async for chunk in response:
|
||||
if isinstance(chunk, dict) and 'choices' in chunk:
|
||||
content = chunk['choices'][0]['delta'].get('content', '')
|
||||
else:
|
||||
content = chunk.choices[0].delta.content
|
||||
|
||||
if content:
|
||||
buffer += content
|
||||
if content[-1].isspace():
|
||||
print(buffer, end='', flush=True)
|
||||
buffer = ""
|
||||
|
||||
if buffer:
|
||||
print(buffer, flush=True)
|
||||
print("\n✨ Stream completed.\n")
|
||||
|
||||
# Add test_bedrock to the test runs
|
||||
# asyncio.run(test_bedrock())
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
# Empty file to mark as package
|
|
@ -172,7 +172,7 @@ class ResultsAdderBase(ABC):
|
|||
Attributes:
|
||||
add_message: Callback for adding new messages
|
||||
update_message: Callback for updating existing messages
|
||||
get_messages: Callback for retrieving thread messages
|
||||
get_llm_history_messages: Callback for retrieving thread messages
|
||||
message_added: Flag tracking if initial message has been added
|
||||
"""
|
||||
|
||||
|
@ -184,7 +184,7 @@ class ResultsAdderBase(ABC):
|
|||
"""
|
||||
self.add_message = thread_manager.add_message
|
||||
self.update_message = thread_manager._update_message
|
||||
self.get_messages = thread_manager.get_messages
|
||||
self.get_llm_history_messages = thread_manager.get_llm_history_messages
|
||||
self.message_added = False
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -9,12 +9,9 @@ This module provides comprehensive processing of LLM responses, including:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, Any, AsyncGenerator, Optional
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
import logging
|
||||
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
|
||||
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
|
||||
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
|
||||
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
|
||||
|
||||
class LLMResponseProcessor:
|
||||
"""Handles LLM response processing and tool execution management.
|
||||
|
@ -37,51 +34,30 @@ class LLMResponseProcessor:
|
|||
def __init__(
|
||||
self,
|
||||
thread_id: str,
|
||||
available_functions: Dict = None,
|
||||
add_message_callback: Callable = None,
|
||||
update_message_callback: Callable = None,
|
||||
get_messages_callback: Callable = None,
|
||||
parallel_tool_execution: bool = True,
|
||||
tool_parser: Optional[ToolParserBase] = None,
|
||||
tool_executor: Optional[ToolExecutorBase] = None,
|
||||
results_adder: Optional[ResultsAdderBase] = None,
|
||||
thread_manager = None
|
||||
tool_executor: ToolExecutorBase,
|
||||
tool_parser: ToolParserBase,
|
||||
available_functions: Dict,
|
||||
results_adder: ResultsAdderBase
|
||||
):
|
||||
"""Initialize the response processor.
|
||||
|
||||
Args:
|
||||
thread_id: ID of the conversation thread
|
||||
available_functions: Dictionary of available tool functions
|
||||
add_message_callback: Callback for adding messages
|
||||
update_message_callback: Callback for updating messages
|
||||
get_messages_callback: Callback for listing messages
|
||||
parallel_tool_execution: Whether to execute tools in parallel
|
||||
tool_parser: Custom tool parser implementation
|
||||
tool_executor: Custom tool executor implementation
|
||||
tool_parser: Custom tool parser implementation
|
||||
available_functions: Dictionary of available tool functions
|
||||
results_adder: Custom results adder implementation
|
||||
thread_manager: Optional thread manager instance
|
||||
"""
|
||||
self.thread_id = thread_id
|
||||
self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution)
|
||||
self.tool_parser = tool_parser or StandardToolParser()
|
||||
self.available_functions = available_functions or {}
|
||||
|
||||
# Create minimal thread manager if needed
|
||||
if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback):
|
||||
class MinimalThreadManager:
|
||||
def __init__(self, add_msg, update_msg, list_msg):
|
||||
self.add_message = add_msg
|
||||
self._update_message = update_msg
|
||||
self.get_messages = list_msg
|
||||
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback)
|
||||
|
||||
self.results_adder = results_adder or StandardResultsAdder(thread_manager)
|
||||
|
||||
# State tracking for streaming
|
||||
self.tool_calls_buffer = {}
|
||||
self.processed_tool_calls = set()
|
||||
self.tool_executor = tool_executor
|
||||
self.tool_parser = tool_parser
|
||||
self.available_functions = available_functions
|
||||
self.results_adder = results_adder
|
||||
self.content_buffer = ""
|
||||
self.tool_calls_buffer = {}
|
||||
self.tool_calls_accumulated = []
|
||||
self.processed_tool_calls = set()
|
||||
self._executing_tools = set() # Track currently executing tools
|
||||
|
||||
async def process_stream(
|
||||
self,
|
||||
|
@ -92,8 +68,9 @@ class LLMResponseProcessor:
|
|||
"""Process streaming LLM response and handle tool execution."""
|
||||
pending_tool_calls = []
|
||||
background_tasks = set()
|
||||
stream_completed = False # New flag to track stream completion
|
||||
|
||||
async def handle_message_management(chunk):
|
||||
async def handle_message_management(chunk, is_final=False):
|
||||
try:
|
||||
# Accumulate content
|
||||
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
|
||||
|
@ -106,28 +83,37 @@ class LLMResponseProcessor:
|
|||
self.tool_calls_buffer
|
||||
)
|
||||
if parsed_message and 'tool_calls' in parsed_message:
|
||||
self.tool_calls_accumulated = parsed_message['tool_calls']
|
||||
new_tool_calls = [
|
||||
tool_call for tool_call in parsed_message['tool_calls']
|
||||
if tool_call['id'] not in self.processed_tool_calls
|
||||
]
|
||||
if new_tool_calls:
|
||||
self.tool_calls_accumulated.extend(new_tool_calls)
|
||||
|
||||
# Handle tool execution and results
|
||||
if execute_tools and self.tool_calls_accumulated:
|
||||
new_tool_calls = [
|
||||
tool_call for tool_call in self.tool_calls_accumulated
|
||||
if tool_call['id'] not in self.processed_tool_calls
|
||||
if (tool_call['id'] not in self.processed_tool_calls and
|
||||
tool_call['id'] not in self._executing_tools)
|
||||
]
|
||||
|
||||
if new_tool_calls:
|
||||
if execute_tools_on_stream:
|
||||
for tool_call in new_tool_calls:
|
||||
self._executing_tools.add(tool_call['id'])
|
||||
|
||||
results = await self.tool_executor.execute_tool_calls(
|
||||
tool_calls=new_tool_calls,
|
||||
available_functions=self.available_functions,
|
||||
thread_id=self.thread_id,
|
||||
executed_tool_calls=self.processed_tool_calls
|
||||
)
|
||||
|
||||
for result in results:
|
||||
await self.results_adder.add_tool_result(self.thread_id, result)
|
||||
self.processed_tool_calls.add(result['tool_call_id'])
|
||||
else:
|
||||
pending_tool_calls.extend(new_tool_calls)
|
||||
self._executing_tools.discard(result['tool_call_id'])
|
||||
|
||||
# Add/update assistant message
|
||||
message = {
|
||||
|
@ -152,7 +138,10 @@ class LLMResponseProcessor:
|
|||
)
|
||||
|
||||
# Handle stream completion
|
||||
if chunk.choices[0].finish_reason:
|
||||
if chunk.choices[0].finish_reason or is_final:
|
||||
nonlocal stream_completed
|
||||
stream_completed = True
|
||||
|
||||
if not execute_tools_on_stream and pending_tool_calls:
|
||||
results = await self.tool_executor.execute_tool_calls(
|
||||
tool_calls=pending_tool_calls,
|
||||
|
@ -165,8 +154,16 @@ class LLMResponseProcessor:
|
|||
self.processed_tool_calls.add(result['tool_call_id'])
|
||||
pending_tool_calls.clear()
|
||||
|
||||
# Set final state on the chunk instead of returning it
|
||||
chunk._final_state = {
|
||||
"content": self.content_buffer,
|
||||
"tool_calls": self.tool_calls_accumulated,
|
||||
"processed_tool_calls": list(self.processed_tool_calls)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in background task: {e}")
|
||||
raise
|
||||
|
||||
try:
|
||||
async for chunk in response_stream:
|
||||
|
@ -175,9 +172,22 @@ class LLMResponseProcessor:
|
|||
task.add_done_callback(background_tasks.discard)
|
||||
yield chunk
|
||||
|
||||
# Create a final dummy chunk to handle completion
|
||||
final_chunk = type('DummyChunk', (), {
|
||||
'choices': [type('DummyChoice', (), {
|
||||
'delta': type('DummyDelta', (), {'content': None}),
|
||||
'finish_reason': 'stop'
|
||||
})]
|
||||
})()
|
||||
|
||||
# Process final state
|
||||
await handle_message_management(final_chunk, is_final=True)
|
||||
yield final_chunk
|
||||
|
||||
# Wait for all background tasks to complete
|
||||
if background_tasks:
|
||||
await asyncio.gather(*background_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in stream processing: {e}")
|
||||
for task in background_tasks:
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
# Empty file to mark as package
|
|
@ -81,6 +81,6 @@ class StandardResultsAdder(ResultsAdderBase):
|
|||
- Checks for duplicate tool results before adding
|
||||
- Adds result only if tool_call_id is unique
|
||||
"""
|
||||
messages = await self.get_messages(thread_id)
|
||||
messages = await self.get_llm_history_messages(thread_id)
|
||||
if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages):
|
||||
await self.add_message(thread_id, result)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
# Empty file to mark as package
|
|
@ -79,7 +79,7 @@ class XMLResultsAdder(ResultsAdderBase):
|
|||
"""
|
||||
try:
|
||||
# Get the original tool call to find the root tag
|
||||
messages = await self.get_messages(thread_id)
|
||||
messages = await self.get_llm_history_messages(thread_id)
|
||||
assistant_msg = next((msg for msg in reversed(messages)
|
||||
if msg['role'] == 'assistant'), None)
|
||||
|
||||
|
@ -107,10 +107,9 @@ class XMLResultsAdder(ResultsAdderBase):
|
|||
await self.add_message(thread_id, result_message)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error adding tool result: {e}")
|
||||
# Ensure the result is still added even if there's an error
|
||||
logging.error(f"Error adding tool result: {e}") # Ensure the result is still added even if there's an error
|
||||
result_message = {
|
||||
"role": "user",
|
||||
"content": f"Result for {result['name']}:\n{result['content']}"
|
||||
}
|
||||
await self.add_message(thread_id, result_message)
|
||||
await self.add_message(thread_id, result_message)
|
||||
|
|
|
@ -38,6 +38,8 @@ class XMLToolExecutor(ToolExecutorBase):
|
|||
"""
|
||||
self.parallel = parallel
|
||||
self.tool_registry = tool_registry or ToolRegistry()
|
||||
# Add internal tracking of executed tools
|
||||
self._executed_tools = set()
|
||||
|
||||
async def execute_tool_calls(
|
||||
self,
|
||||
|
@ -65,20 +67,28 @@ class XMLToolExecutor(ToolExecutorBase):
|
|||
if executed_tool_calls is None:
|
||||
executed_tool_calls = set()
|
||||
|
||||
# Filter out already executed tool calls
|
||||
new_tool_calls = [
|
||||
tool_call for tool_call in tool_calls
|
||||
if tool_call['id'] not in executed_tool_calls
|
||||
and tool_call['id'] not in self._executed_tools
|
||||
]
|
||||
|
||||
if not new_tool_calls:
|
||||
return []
|
||||
|
||||
if self.parallel:
|
||||
return await self._execute_parallel(
|
||||
tool_calls,
|
||||
available_functions,
|
||||
thread_id,
|
||||
executed_tool_calls
|
||||
)
|
||||
results = await self._execute_parallel(new_tool_calls, available_functions, thread_id, executed_tool_calls)
|
||||
else:
|
||||
return await self._execute_sequential(
|
||||
tool_calls,
|
||||
available_functions,
|
||||
thread_id,
|
||||
executed_tool_calls
|
||||
)
|
||||
results = await self._execute_sequential(new_tool_calls, available_functions, thread_id, executed_tool_calls)
|
||||
|
||||
# Track executed tools internally
|
||||
for tool_call in new_tool_calls:
|
||||
self._executed_tools.add(tool_call['id'])
|
||||
if executed_tool_calls is not None:
|
||||
executed_tool_calls.add(tool_call['id'])
|
||||
|
||||
return results
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
|
@ -87,9 +97,10 @@ class XMLToolExecutor(ToolExecutorBase):
|
|||
thread_id: str,
|
||||
executed_tool_calls: Set[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
async def execute_single_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if tool_call['id'] in executed_tool_calls:
|
||||
logging.info(f"Tool call {tool_call['id']} already executed")
|
||||
async def execute_single_tool(tool_call: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
# Double-check the tool hasn't been executed
|
||||
if (tool_call['id'] in executed_tool_calls or
|
||||
tool_call['id'] in self._executed_tools):
|
||||
return None
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,207 +1,118 @@
|
|||
"""
|
||||
Manages persistent state storage for AgentPress components using thread-based events.
|
||||
|
||||
The StateManager provides thread-safe access to state data stored as events in threads,
|
||||
allowing components to save and retrieve data across sessions. Each state update
|
||||
creates a new event containing the complete state.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional, List, Dict
|
||||
from asyncio import Lock
|
||||
import uuid
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
|
||||
class StateManager:
|
||||
"""
|
||||
Manages persistent state storage for AgentPress components using thread events.
|
||||
|
||||
The StateManager provides thread-safe access to state data stored as events,
|
||||
maintaining the complete state in each event for better consistency and tracking.
|
||||
|
||||
Attributes:
|
||||
lock (Lock): Asyncio lock for thread-safe state access
|
||||
thread_id (str): Thread ID for state storage
|
||||
thread_manager (ThreadManager): Thread manager instance for event handling
|
||||
Manages state storage using thread messages.
|
||||
Each state message contains a complete snapshot of the state at that point in time.
|
||||
"""
|
||||
|
||||
def __init__(self, thread_id: str):
|
||||
"""
|
||||
Initialize StateManager with thread ID.
|
||||
|
||||
Args:
|
||||
thread_id (str): Thread ID for state storage
|
||||
"""
|
||||
self.lock = Lock()
|
||||
self.thread_id = thread_id
|
||||
"""Initialize StateManager with a thread ID."""
|
||||
self.thread_manager = ThreadManager()
|
||||
logging.info(f"StateManager initialized with thread_id: {self.thread_id}")
|
||||
self.thread_id = thread_id
|
||||
self._state_cache = None
|
||||
logging.info(f"StateManager initialized for thread: {thread_id}")
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the thread manager."""
|
||||
await self.thread_manager.initialize()
|
||||
async def _get_state(self) -> Dict[str, Any]:
|
||||
"""Get the current complete state."""
|
||||
if self._state_cache is not None:
|
||||
return self._state_cache.copy() # Return copy to prevent cache mutation
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
"""Ensure thread manager is initialized."""
|
||||
if not self.thread_manager.db._initialized:
|
||||
await self.initialize()
|
||||
|
||||
async def _get_current_state(self) -> dict:
|
||||
"""Get the current state from the latest state event."""
|
||||
await self._ensure_initialized()
|
||||
events = await self.thread_manager.get_thread_events(
|
||||
thread_id=self.thread_id,
|
||||
event_types=["state"],
|
||||
order_by="created_at",
|
||||
order="DESC"
|
||||
# Get the latest state message
|
||||
rows = await self.thread_manager.db.fetch_all(
|
||||
"""
|
||||
SELECT content
|
||||
FROM messages
|
||||
WHERE thread_id = ? AND type = 'state_message'
|
||||
ORDER BY created_at DESC LIMIT 1
|
||||
""",
|
||||
(self.thread_id,)
|
||||
)
|
||||
if events:
|
||||
return events[0]["content"].get("state", {})
|
||||
|
||||
if rows:
|
||||
try:
|
||||
self._state_cache = json.loads(rows[0][0])
|
||||
return self._state_cache.copy()
|
||||
except json.JSONDecodeError:
|
||||
logging.error("Failed to parse state JSON")
|
||||
|
||||
return {}
|
||||
|
||||
async def _save_state(self, state: dict):
|
||||
"""Save the complete state as a new event."""
|
||||
await self._ensure_initialized()
|
||||
await self.thread_manager.create_event(
|
||||
async def _save_state(self, state: Dict[str, Any]):
|
||||
"""Save a new complete state snapshot."""
|
||||
# Format state as a string with proper indentation
|
||||
formatted_state = json.dumps(state, indent=2)
|
||||
|
||||
# Save new state message with complete snapshot
|
||||
await self.thread_manager.add_message(
|
||||
thread_id=self.thread_id,
|
||||
event_type="state",
|
||||
content={"state": state}
|
||||
message_data=formatted_state,
|
||||
message_type='state_message',
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
# Update cache with a copy
|
||||
self._state_cache = state.copy()
|
||||
|
||||
async def set(self, key: str, data: Any) -> Any:
|
||||
"""
|
||||
Store data with a key in the state.
|
||||
|
||||
Args:
|
||||
key (str): Simple string key like "config" or "settings"
|
||||
data (Any): Any JSON-serializable data
|
||||
|
||||
Returns:
|
||||
Any: The stored data
|
||||
"""
|
||||
async with self.lock:
|
||||
try:
|
||||
current_state = await self._get_current_state()
|
||||
current_state[key] = data
|
||||
await self._save_state(current_state)
|
||||
logging.info(f'Updated state key: {key}')
|
||||
return data
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting state: {e}")
|
||||
raise
|
||||
"""Store any JSON-serializable data with a key."""
|
||||
state = await self._get_state()
|
||||
state[key] = data
|
||||
await self._save_state(state)
|
||||
logging.info(f'Updated state key: {key}')
|
||||
return data
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get data for a key from the current state.
|
||||
|
||||
Args:
|
||||
key (str): Simple string key like "config" or "settings"
|
||||
|
||||
Returns:
|
||||
Any: The stored data for the key, or None if key not found
|
||||
"""
|
||||
async with self.lock:
|
||||
try:
|
||||
current_state = await self._get_current_state()
|
||||
if key in current_state:
|
||||
logging.info(f'Retrieved key: {key}')
|
||||
return current_state[key]
|
||||
logging.info(f'Key not found: {key}')
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting state: {e}")
|
||||
raise
|
||||
"""Get data for a key."""
|
||||
state = await self._get_state()
|
||||
if key in state:
|
||||
data = state[key]
|
||||
logging.info(f'Retrieved key: {key}')
|
||||
return data
|
||||
logging.info(f'Key not found: {key}')
|
||||
return None
|
||||
|
||||
async def delete(self, key: str):
|
||||
"""
|
||||
Delete a key from the state.
|
||||
|
||||
Args:
|
||||
key (str): Simple string key like "config" or "settings"
|
||||
"""
|
||||
async with self.lock:
|
||||
try:
|
||||
current_state = await self._get_current_state()
|
||||
if key in current_state:
|
||||
del current_state[key]
|
||||
await self._save_state(current_state)
|
||||
logging.info(f"Deleted key: {key}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting state: {e}")
|
||||
raise
|
||||
"""Delete data for a key."""
|
||||
state = await self._get_state()
|
||||
if key in state:
|
||||
del state[key]
|
||||
await self._save_state(state)
|
||||
logging.info(f"Deleted key: {key}")
|
||||
|
||||
async def update(self, key: str, data: Dict[str, Any]) -> Optional[Any]:
|
||||
"""
|
||||
Update existing dictionary data for a key by merging.
|
||||
|
||||
Args:
|
||||
key (str): Simple string key like "config" or "settings"
|
||||
data (Dict[str, Any]): Dictionary of updates to merge
|
||||
|
||||
Returns:
|
||||
Optional[Any]: Updated data if successful, None if key not found
|
||||
"""
|
||||
async with self.lock:
|
||||
try:
|
||||
current_state = await self._get_current_state()
|
||||
if key in current_state and isinstance(current_state[key], dict):
|
||||
current_state[key].update(data)
|
||||
await self._save_state(current_state)
|
||||
return current_state[key]
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating state: {e}")
|
||||
raise
|
||||
"""Update existing data for a key by merging dictionaries."""
|
||||
state = await self._get_state()
|
||||
if key in state and isinstance(state[key], dict):
|
||||
state[key].update(data)
|
||||
await self._save_state(state)
|
||||
logging.info(f'Updated state key: {key}')
|
||||
return state[key]
|
||||
return None
|
||||
|
||||
async def append(self, key: str, item: Any) -> Optional[List[Any]]:
|
||||
"""
|
||||
Append an item to a list stored at key.
|
||||
|
||||
Args:
|
||||
key (str): Simple string key like "config" or "settings"
|
||||
item (Any): Item to append
|
||||
|
||||
Returns:
|
||||
Optional[List[Any]]: Updated list if successful, None if key not found
|
||||
"""
|
||||
async with self.lock:
|
||||
try:
|
||||
current_state = await self._get_current_state()
|
||||
if key not in current_state:
|
||||
current_state[key] = []
|
||||
if isinstance(current_state[key], list):
|
||||
current_state[key].append(item)
|
||||
await self._save_state(current_state)
|
||||
return current_state[key]
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error appending to state: {e}")
|
||||
raise
|
||||
"""Append an item to a list stored at key."""
|
||||
state = await self._get_state()
|
||||
if key not in state:
|
||||
state[key] = []
|
||||
if isinstance(state[key], list):
|
||||
state[key].append(item)
|
||||
await self._save_state(state)
|
||||
logging.info(f'Appended to key: {key}')
|
||||
return state[key]
|
||||
return None
|
||||
|
||||
async def get_latest_state(self) -> dict:
|
||||
"""
|
||||
Get the latest complete state.
|
||||
|
||||
Returns:
|
||||
dict: Complete contents of the latest state
|
||||
"""
|
||||
async with self.lock:
|
||||
try:
|
||||
state = await self._get_current_state()
|
||||
logging.info(f"Retrieved latest state with {len(state)} keys")
|
||||
return state
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting latest state: {e}")
|
||||
raise
|
||||
async def export_store(self) -> dict:
|
||||
"""Export entire state."""
|
||||
state = await self._get_state()
|
||||
return state
|
||||
|
||||
async def clear_state(self):
|
||||
"""
|
||||
Clear the entire state.
|
||||
"""
|
||||
async with self.lock:
|
||||
try:
|
||||
await self._save_state({})
|
||||
logging.info("Cleared state")
|
||||
except Exception as e:
|
||||
logging.error(f"Error clearing state: {e}")
|
||||
raise
|
||||
async def clear_store(self):
|
||||
"""Clear entire state."""
|
||||
await self._save_state({})
|
||||
self._state_cache = {}
|
||||
logging.info("Cleared state")
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,64 +1,111 @@
|
|||
import streamlit as st
|
||||
from datetime import datetime
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from agentpress.db_connection import DBConnection
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
def format_message_content(content):
|
||||
"""Format message content handling both string and list formats."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
formatted_content = []
|
||||
for item in content:
|
||||
if item.get('type') == 'text':
|
||||
formatted_content.append(item['text'])
|
||||
elif item.get('type') == 'image_url':
|
||||
formatted_content.append("[Image]")
|
||||
return "\n".join(formatted_content)
|
||||
return str(content)
|
||||
"""Format message content handling various formats."""
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
# Try to parse JSON strings
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, (dict, list)):
|
||||
return json.dumps(parsed, indent=2)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
formatted_content = []
|
||||
for item in content:
|
||||
if item.get('type') == 'text':
|
||||
formatted_content.append(item['text'])
|
||||
elif item.get('type') == 'image_url':
|
||||
formatted_content.append("[Image]")
|
||||
return "\n".join(formatted_content)
|
||||
return json.dumps(content, indent=2)
|
||||
except:
|
||||
return str(content)
|
||||
|
||||
async def load_threads():
|
||||
"""Load all thread IDs from the database."""
|
||||
db = DBConnection()
|
||||
rows = await db.fetch_all("SELECT thread_id, created_at FROM threads ORDER BY created_at DESC")
|
||||
rows = await db.fetch_all(
|
||||
"""
|
||||
SELECT id, created_at
|
||||
FROM threads
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
)
|
||||
return rows
|
||||
|
||||
async def load_thread_content(thread_id: str):
|
||||
"""Load the content of a specific thread from the database."""
|
||||
thread_manager = ThreadManager()
|
||||
return await thread_manager.get_messages(thread_id)
|
||||
async def load_thread_content(thread_id: str, filters: dict):
|
||||
"""Load messages from a thread with filters."""
|
||||
db = DBConnection()
|
||||
|
||||
query_parts = ["SELECT type, content, include_in_llm_message_history, created_at FROM messages WHERE thread_id = ?"]
|
||||
params = [thread_id]
|
||||
|
||||
if filters.get('message_types'):
|
||||
# Convert comma-separated string to list and clean up whitespace
|
||||
types_list = [t.strip() for t in filters['message_types'].split(',') if t.strip()]
|
||||
if types_list:
|
||||
query_parts.append("AND type IN (" + ",".join(["?" for _ in types_list]) + ")")
|
||||
params.extend(types_list)
|
||||
|
||||
if filters.get('exclude_message_types'):
|
||||
# Convert comma-separated string to list and clean up whitespace
|
||||
exclude_types_list = [t.strip() for t in filters['exclude_message_types'].split(',') if t.strip()]
|
||||
if exclude_types_list:
|
||||
query_parts.append("AND type NOT IN (" + ",".join(["?" for _ in exclude_types_list]) + ")")
|
||||
params.extend(exclude_types_list)
|
||||
|
||||
if filters.get('before_timestamp'):
|
||||
query_parts.append("AND created_at < ?")
|
||||
params.append(filters['before_timestamp'])
|
||||
|
||||
if filters.get('after_timestamp'):
|
||||
query_parts.append("AND created_at > ?")
|
||||
params.append(filters['after_timestamp'])
|
||||
|
||||
if filters.get('include_in_llm_message_history') is not None:
|
||||
query_parts.append("AND include_in_llm_message_history = ?")
|
||||
params.append(filters['include_in_llm_message_history'])
|
||||
|
||||
# Add ordering
|
||||
order_direction = "DESC" if filters.get('order', 'asc').lower() == 'desc' else "ASC"
|
||||
query_parts.append(f"ORDER BY created_at {order_direction}")
|
||||
|
||||
# Add limit and offset
|
||||
if filters.get('limit'):
|
||||
query_parts.append("LIMIT ?")
|
||||
params.append(filters['limit'])
|
||||
|
||||
if filters.get('offset'):
|
||||
query_parts.append("OFFSET ?")
|
||||
params.append(filters['offset'])
|
||||
|
||||
query = " ".join(query_parts)
|
||||
rows = await db.fetch_all(query, tuple(params))
|
||||
return rows
|
||||
|
||||
def render_message(role, content, avatar):
|
||||
"""Render a message with a consistent chat-like style."""
|
||||
# Create columns for avatar and message
|
||||
col1, col2 = st.columns([1, 11])
|
||||
|
||||
# Style based on role
|
||||
if role == "assistant":
|
||||
bgcolor = "rgba(25, 25, 25, 0.05)"
|
||||
elif role == "user":
|
||||
bgcolor = "rgba(25, 120, 180, 0.05)"
|
||||
elif role == "system":
|
||||
bgcolor = "rgba(180, 25, 25, 0.05)"
|
||||
else:
|
||||
bgcolor = "rgba(100, 100, 100, 0.05)"
|
||||
|
||||
# Display avatar in first column
|
||||
def render_message(msg_type: str, content: str, include_in_llm: bool, timestamp: str):
|
||||
"""Render a message using Streamlit components."""
|
||||
# Message type and metadata
|
||||
col1, col2 = st.columns([3, 1])
|
||||
with col1:
|
||||
st.markdown(f"<div style='text-align: center; font-size: 24px;'>{avatar}</div>", unsafe_allow_html=True)
|
||||
|
||||
# Display message in second column
|
||||
st.text(f"Type: {msg_type}")
|
||||
with col2:
|
||||
st.markdown(
|
||||
f"""
|
||||
<div style='background-color: {bgcolor}; padding: 10px; border-radius: 5px;'>
|
||||
<strong>{role.upper()}</strong><br>
|
||||
{content}
|
||||
</div>
|
||||
""",
|
||||
unsafe_allow_html=True
|
||||
)
|
||||
st.text("🟢 LLM" if include_in_llm else "⚫ Non-LLM")
|
||||
|
||||
# Timestamp
|
||||
st.text(f"Time: {datetime.fromisoformat(timestamp).strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Message content
|
||||
st.code(content, language="json")
|
||||
|
||||
# Separator
|
||||
st.divider()
|
||||
|
||||
def main():
|
||||
st.title("Thread Viewer")
|
||||
|
@ -86,7 +133,6 @@ def main():
|
|||
)
|
||||
|
||||
if selected_thread_display:
|
||||
# Get the actual thread ID from the display string
|
||||
selected_thread_id = thread_options[selected_thread_display]
|
||||
|
||||
# Display thread ID in sidebar
|
||||
|
@ -95,46 +141,77 @@ def main():
|
|||
# Add refresh button
|
||||
if st.sidebar.button("🔄 Refresh Thread"):
|
||||
st.session_state.threads = asyncio.run(load_threads())
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
||||
# Load and display messages
|
||||
messages = asyncio.run(load_thread_content(selected_thread_id))
|
||||
# Advanced filtering options in sidebar
|
||||
st.sidebar.title("Filter Options")
|
||||
|
||||
# Display messages in chat-like interface
|
||||
for message in messages:
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")
|
||||
|
||||
# Determine avatar based on role
|
||||
if role == "assistant":
|
||||
avatar = "🤖"
|
||||
elif role == "user":
|
||||
avatar = "👤"
|
||||
elif role == "system":
|
||||
avatar = "⚙️"
|
||||
elif role == "tool":
|
||||
avatar = "🔧"
|
||||
else:
|
||||
avatar = "❓"
|
||||
|
||||
# Format the content
|
||||
# Message type filters
|
||||
col1, col2 = st.sidebar.columns(2)
|
||||
with col1:
|
||||
message_types = st.text_input(
|
||||
"Include Types",
|
||||
help="Enter message types to include, separated by commas"
|
||||
)
|
||||
with col2:
|
||||
exclude_message_types = st.text_input(
|
||||
"Exclude Types",
|
||||
help="Enter message types to exclude, separated by commas"
|
||||
)
|
||||
|
||||
# Limit and offset
|
||||
col1, col2 = st.sidebar.columns(2)
|
||||
with col1:
|
||||
limit = st.number_input("Limit", min_value=1, value=50)
|
||||
with col2:
|
||||
offset = st.number_input("Offset", min_value=0, value=0)
|
||||
|
||||
# Timestamp filters
|
||||
st.sidebar.subheader("Time Range")
|
||||
before_timestamp = st.sidebar.date_input("Before Date", value=None)
|
||||
after_timestamp = st.sidebar.date_input("After Date", value=None)
|
||||
|
||||
# LLM history filter
|
||||
include_in_llm = st.sidebar.radio(
|
||||
"LLM History Filter",
|
||||
options=["All Messages", "LLM Only", "Non-LLM Only"]
|
||||
)
|
||||
|
||||
# Sort order
|
||||
order = st.sidebar.radio("Sort Order", ["Ascending", "Descending"])
|
||||
|
||||
# Prepare filters
|
||||
filters = {
|
||||
'message_types': message_types if message_types else None,
|
||||
'exclude_message_types': exclude_message_types if exclude_message_types else None,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
'order': 'desc' if order == "Descending" else 'asc'
|
||||
}
|
||||
|
||||
# Add timestamp filters if selected
|
||||
if before_timestamp:
|
||||
filters['before_timestamp'] = before_timestamp.isoformat()
|
||||
if after_timestamp:
|
||||
filters['after_timestamp'] = after_timestamp.isoformat()
|
||||
|
||||
# Add LLM history filter
|
||||
if include_in_llm == "LLM Only":
|
||||
filters['include_in_llm_message_history'] = True
|
||||
elif include_in_llm == "Non-LLM Only":
|
||||
filters['include_in_llm_message_history'] = False
|
||||
|
||||
# Load messages with filters
|
||||
messages = asyncio.run(load_thread_content(selected_thread_id, filters))
|
||||
|
||||
if not messages:
|
||||
st.info("No messages found with current filters")
|
||||
return
|
||||
|
||||
# Display messages
|
||||
for msg_type, content, include_in_llm, timestamp in messages:
|
||||
formatted_content = format_message_content(content)
|
||||
|
||||
# Render the message
|
||||
render_message(role, formatted_content, avatar)
|
||||
|
||||
# Display tool calls if present
|
||||
if "tool_calls" in message:
|
||||
with st.expander("🛠️ Tool Calls"):
|
||||
for tool_call in message["tool_calls"]:
|
||||
st.code(
|
||||
f"Function: {tool_call['function']['name']}\n"
|
||||
f"Arguments: {tool_call['function']['arguments']}",
|
||||
language="json"
|
||||
)
|
||||
|
||||
# Add some spacing between messages
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
render_message(msg_type, formatted_content, include_in_llm, timestamp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
# Empty file to mark as package
|
|
@ -0,0 +1,261 @@
|
|||
"""
|
||||
Interactive web development agent supporting both XML and Standard LLM tool calling.
|
||||
|
||||
This agent can:
|
||||
- Create and modify web projects
|
||||
- Execute terminal commands
|
||||
- Handle file operations
|
||||
- Use either XML or Standard tool calling patterns
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from example.tools.files_tool import FilesTool
|
||||
from agentpress.state_manager import StateManager
|
||||
from example.tools.terminal_tool import TerminalTool
|
||||
import logging
|
||||
from typing import AsyncGenerator, Optional, Dict, Any
|
||||
import sys
|
||||
|
||||
from agentpress.api.api_factory import register_thread_task_api
|
||||
|
||||
BASE_SYSTEM_MESSAGE = """
|
||||
You are a world-class web developer who can create, edit, and delete files, and execute terminal commands.
|
||||
You write clean, well-structured code. Keep iterating on existing files, continue working on this existing
|
||||
codebase - do not omit previous progress; instead, keep iterating.
|
||||
Available tools:
|
||||
- create_file: Create new files with specified content
|
||||
- delete_file: Remove existing files
|
||||
- str_replace: Make precise text replacements in files
|
||||
- execute_command: Run terminal commands
|
||||
|
||||
|
||||
RULES:
|
||||
- All current file contents are available to you in the <current_workspace_state> section
|
||||
- Each file in the workspace state includes its full content
|
||||
- Use str_replace for precise replacements in files
|
||||
- NEVER include comments in any code you write - the code should be self-documenting
|
||||
- Always maintain the full context of files when making changes
|
||||
- When creating new files, write clean code without any comments or documentation
|
||||
|
||||
<available_tools>
|
||||
[create_file(file_path, file_contents)] - Create new files
|
||||
[delete_file(file_path)] - Delete existing files
|
||||
[str_replace(file_path, old_str, new_str)] - Replace specific text in files
|
||||
[execute_command(command)] - Execute terminal commands
|
||||
</available_tools>
|
||||
|
||||
ALWAYS RESPOND WITH MULTIPLE SIMULTANEOUS ACTIONS:
|
||||
<thoughts>
|
||||
[Provide a concise overview of your planned changes and implementations]
|
||||
</thoughts>
|
||||
|
||||
<actions>
|
||||
[Include multiple tool calls]
|
||||
</actions>
|
||||
|
||||
EDITING GUIDELINES:
|
||||
1. Review the current file contents in the workspace state
|
||||
2. Make targeted changes with str_replace
|
||||
3. Write clean, self-documenting code without comments
|
||||
4. Use create_file for new files and str_replace for modifications
|
||||
|
||||
Example workspace state for a file:
|
||||
{
|
||||
"index.html": {
|
||||
"content": "<!DOCTYPE html>\\n<html>\\n<head>..."
|
||||
}
|
||||
}
|
||||
Think deeply and step by step.
|
||||
"""
|
||||
|
||||
XML_FORMAT = """
|
||||
RESPONSE FORMAT:
|
||||
Use XML tags to specify file operations:
|
||||
|
||||
<create-file file_path="path/to/file">
|
||||
file contents here
|
||||
</create-file>
|
||||
|
||||
<str-replace file_path="path/to/file">
|
||||
<old_str>text to replace</old_str>
|
||||
<new_str>replacement text</new_str>
|
||||
</str-replace>
|
||||
|
||||
<delete-file file_path="path/to/file">
|
||||
</delete-file>
|
||||
|
||||
<stop_session></stop_session>
|
||||
"""
|
||||
|
||||
@register_thread_task_api("/agent")
|
||||
async def run_agent(
|
||||
thread_id: str,
|
||||
max_iterations: int = 5,
|
||||
user_input: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the development agent with specified configuration.
|
||||
|
||||
Args:
|
||||
thread_id (str): The ID of the thread.
|
||||
max_iterations (int, optional): The maximum number of iterations. Defaults to 5.
|
||||
user_input (Optional[str], optional): The user input. Defaults to None.
|
||||
"""
|
||||
thread_manager = ThreadManager()
|
||||
state_manager = StateManager(thread_id)
|
||||
|
||||
if user_input:
|
||||
await thread_manager.add_message(
|
||||
thread_id,
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_input
|
||||
}
|
||||
)
|
||||
|
||||
thread_manager.add_tool(FilesTool, thread_id=thread_id)
|
||||
thread_manager.add_tool(TerminalTool, thread_id=thread_id)
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": BASE_SYSTEM_MESSAGE + XML_FORMAT
|
||||
}
|
||||
|
||||
iteration = 0
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
|
||||
files_tool = FilesTool(thread_id=thread_id)
|
||||
|
||||
state = await state_manager.export_store()
|
||||
|
||||
temporary_message_content = f"""
|
||||
You are tasked to complete the LATEST USER REQUEST!
|
||||
<latest_user_request>
|
||||
{user_input}
|
||||
</latest_user_request>
|
||||
|
||||
Current development environment workspace state:
|
||||
<current_workspace_state>
|
||||
{json.dumps(state, indent=2) if state else "{}"}
|
||||
</current_workspace_state>
|
||||
|
||||
CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE.
|
||||
"""
|
||||
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
message_data=temporary_message_content,
|
||||
message_type="temporary_message",
|
||||
include_in_llm_message_history=False
|
||||
)
|
||||
|
||||
temporary_message = {
|
||||
"role": "user",
|
||||
"content": temporary_message_content
|
||||
}
|
||||
|
||||
model_name = "anthropic/claude-3-5-sonnet-latest"
|
||||
|
||||
response = await thread_manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
system_message=system_message,
|
||||
model_name=model_name,
|
||||
temperature=0.1,
|
||||
max_tokens=8096,
|
||||
tool_choice="auto",
|
||||
temporary_message=temporary_message,
|
||||
native_tool_calling=False,
|
||||
xml_tool_calling=True,
|
||||
stream=True,
|
||||
execute_tools_on_stream=True,
|
||||
parallel_tool_execution=True,
|
||||
)
|
||||
|
||||
if isinstance(response, AsyncGenerator):
|
||||
print("\n🤖 Assistant is responding:")
|
||||
try:
|
||||
async for chunk in response:
|
||||
if hasattr(chunk.choices[0], 'delta'):
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
if hasattr(delta, 'content') and delta.content is not None:
|
||||
content = delta.content
|
||||
print(content, end='', flush=True)
|
||||
|
||||
# Check for open_files_in_editor tag and continue if found
|
||||
if '</open_files_in_editor>' in content:
|
||||
print("\n📂 Opening files in editor, continuing to next iteration...")
|
||||
continue
|
||||
|
||||
if hasattr(delta, 'tool_calls') and delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
|
||||
if tool_call.function.arguments:
|
||||
print(f" {tool_call.function.arguments}", end='', flush=True)
|
||||
|
||||
print("\n✨ Response completed\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error processing stream: {e}", file=sys.stderr)
|
||||
logging.error(f"Error processing stream: {e}")
|
||||
else:
|
||||
print("\nNon-streaming response received:", response)
|
||||
|
||||
# # Get latest assistant message and check for stop_session
|
||||
# latest_msg = await thread_manager.get_llm_history_messages(
|
||||
# thread_id=thread_id,
|
||||
# only_latest_assistant=True
|
||||
# )
|
||||
# if latest_msg and '</stop_session>' in latest_msg:
|
||||
# break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n🚀 Welcome to AgentPress!")
|
||||
|
||||
project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ")
|
||||
if not project_description.strip():
|
||||
project_description = "Create a modern, responsive landing page"
|
||||
|
||||
print("\nChoose your agent type:")
|
||||
print("1. XML-based Tool Calling")
|
||||
print(" - Structured XML format for tool execution")
|
||||
print(" - Parses tool calls using XML outputs in the LLM response")
|
||||
|
||||
print("\n2. Standard Function Calling")
|
||||
print(" - Native LLM function calling format")
|
||||
print(" - JSON-based parameter passing")
|
||||
|
||||
use_xml = input("\nSelect tool calling format [1/2] (default: 1): ").strip() != "2"
|
||||
|
||||
print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}")
|
||||
print("Use Ctrl+C to stop the agent at any time.")
|
||||
|
||||
async def test_agent():
|
||||
thread_manager = ThreadManager()
|
||||
thread_id = await thread_manager.create_thread()
|
||||
logging.info(f"Created new thread: {thread_id}")
|
||||
|
||||
try:
|
||||
result = await run_agent(
|
||||
thread_id=thread_id,
|
||||
max_iterations=5,
|
||||
user_input=project_description,
|
||||
)
|
||||
print("\n✅ Test completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed: {str(e)}")
|
||||
raise
|
||||
|
||||
try:
|
||||
asyncio.run(test_agent())
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Test interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed with error: {str(e)}")
|
||||
raise
|
|
@ -0,0 +1,297 @@
|
|||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
||||
from agentpress.state_manager import StateManager
|
||||
from typing import Optional
|
||||
|
||||
class FilesTool(Tool):
|
||||
"""File management tool for creating, updating, and deleting files.
|
||||
|
||||
This tool provides file operations within a workspace directory, with built-in
|
||||
file filtering and state tracking capabilities.
|
||||
|
||||
Attributes:
|
||||
workspace (str): Path to the workspace directory
|
||||
EXCLUDED_FILES (set): Files to exclude from operations
|
||||
EXCLUDED_DIRS (set): Directories to exclude
|
||||
EXCLUDED_EXT (set): File extensions to exclude
|
||||
SNIPPET_LINES (int): Context lines for edit previews
|
||||
"""
|
||||
|
||||
# Excluded files, directories, and extensions
|
||||
EXCLUDED_FILES = {
|
||||
".DS_Store",
|
||||
".gitignore",
|
||||
"package-lock.json",
|
||||
"postcss.config.js",
|
||||
"postcss.config.mjs",
|
||||
"jsconfig.json",
|
||||
"components.json",
|
||||
"tsconfig.tsbuildinfo",
|
||||
"tsconfig.json",
|
||||
}
|
||||
|
||||
EXCLUDED_DIRS = {
|
||||
"node_modules",
|
||||
".next",
|
||||
"dist",
|
||||
"build",
|
||||
".git"
|
||||
}
|
||||
|
||||
EXCLUDED_EXT = {
|
||||
".ico",
|
||||
".svg",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".webp",
|
||||
".db",
|
||||
".sql"
|
||||
}
|
||||
|
||||
def __init__(self, thread_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
if thread_id:
|
||||
self.state_manager = StateManager(thread_id)
|
||||
asyncio.create_task(self._init_workspace_state())
|
||||
self.SNIPPET_LINES = 4
|
||||
|
||||
def _should_exclude_file(self, rel_path: str) -> bool:
|
||||
"""Check if a file should be excluded based on path, name, or extension"""
|
||||
# Check filename
|
||||
filename = os.path.basename(rel_path)
|
||||
if filename in self.EXCLUDED_FILES:
|
||||
return True
|
||||
|
||||
# Check directory
|
||||
dir_path = os.path.dirname(rel_path)
|
||||
if any(excluded in dir_path for excluded in self.EXCLUDED_DIRS):
|
||||
return True
|
||||
|
||||
# Check extension
|
||||
_, ext = os.path.splitext(filename)
|
||||
if ext.lower() in self.EXCLUDED_EXT:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _init_workspace_state(self):
|
||||
"""Initialize or update the workspace state in JSON"""
|
||||
files_state = {}
|
||||
|
||||
# Walk through workspace and record all files
|
||||
for root, _, files in os.walk(self.workspace):
|
||||
for file in files:
|
||||
full_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(full_path, self.workspace)
|
||||
|
||||
# Skip excluded files
|
||||
if self._should_exclude_file(rel_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(full_path, 'r') as f:
|
||||
content = f.read()
|
||||
files_state[rel_path] = {
|
||||
"content": content
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error reading file {rel_path}: {e}")
|
||||
except UnicodeDecodeError:
|
||||
print(f"Skipping binary file: {rel_path}")
|
||||
|
||||
if hasattr(self, 'state_manager'):
|
||||
await self.state_manager.set("files", files_state)
|
||||
|
||||
async def _update_workspace_state(self):
|
||||
"""Update the workspace state after any file operation"""
|
||||
await self._init_workspace_state()
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "create_file",
|
||||
"description": "Create a new file with the provided contents at a given path in the workspace",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to be created"
|
||||
},
|
||||
"file_contents": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "file_contents"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@xml_schema(
|
||||
tag_name="create-file",
|
||||
mappings=[
|
||||
{"param_name": "file_path", "node_type": "attribute", "path": "."},
|
||||
{"param_name": "file_contents", "node_type": "content", "path": "."}
|
||||
],
|
||||
example='''
|
||||
<create-file file_path="path/to/file">
|
||||
File contents go here
|
||||
</create-file>
|
||||
'''
|
||||
)
|
||||
async def create_file(self, file_path: str, file_contents: str) -> ToolResult:
|
||||
try:
|
||||
full_path = os.path.join(self.workspace, file_path)
|
||||
if os.path.exists(full_path):
|
||||
return self.fail_response(f"File '{file_path}' already exists. Use update_file to modify existing files.")
|
||||
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
with open(full_path, 'w') as f:
|
||||
f.write(file_contents)
|
||||
|
||||
await self._update_workspace_state()
|
||||
return self.success_response(f"File '{file_path}' created successfully.")
|
||||
except Exception as e:
|
||||
return self.fail_response(f"Error creating file: {str(e)}")
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "delete_file",
|
||||
"description": "Delete a file at the given path",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to be deleted"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@xml_schema(
|
||||
tag_name="delete-file",
|
||||
mappings=[
|
||||
{"param_name": "file_path", "node_type": "attribute", "path": "."}
|
||||
],
|
||||
example='''
|
||||
<delete-file file_path="path/to/file">
|
||||
</delete-file>
|
||||
'''
|
||||
)
|
||||
async def delete_file(self, file_path: str) -> ToolResult:
|
||||
try:
|
||||
full_path = os.path.join(self.workspace, file_path)
|
||||
os.remove(full_path)
|
||||
|
||||
await self._update_workspace_state()
|
||||
return self.success_response(f"File '{file_path}' deleted successfully.")
|
||||
except Exception as e:
|
||||
return self.fail_response(f"Error deleting file: {str(e)}")
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "str_replace",
|
||||
"description": "Replace text in file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the target file"
|
||||
},
|
||||
"old_str": {
|
||||
"type": "string",
|
||||
"description": "Text to be replaced (must appear exactly once)"
|
||||
},
|
||||
"new_str": {
|
||||
"type": "string",
|
||||
"description": "Replacement text"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "old_str", "new_str"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@xml_schema(
|
||||
tag_name="str-replace",
|
||||
mappings=[
|
||||
{"param_name": "file_path", "node_type": "attribute", "path": "file_path"},
|
||||
{"param_name": "old_str", "node_type": "element", "path": "old_str"},
|
||||
{"param_name": "new_str", "node_type": "element", "path": "new_str"}
|
||||
],
|
||||
example='''
|
||||
<str-replace file_path="path/to/file">
|
||||
<old_str>text to replace</old_str>
|
||||
<new_str>replacement text</new_str>
|
||||
</str-replace>
|
||||
'''
|
||||
)
|
||||
async def str_replace(self, file_path: str, old_str: str, new_str: str) -> ToolResult:
|
||||
try:
|
||||
full_path = Path(os.path.join(self.workspace, file_path))
|
||||
if not full_path.exists():
|
||||
return self.fail_response(f"File '{file_path}' does not exist")
|
||||
|
||||
content = full_path.read_text().expandtabs()
|
||||
old_str = old_str.expandtabs()
|
||||
new_str = new_str.expandtabs()
|
||||
|
||||
occurrences = content.count(old_str)
|
||||
if occurrences == 0:
|
||||
return self.fail_response(f"String '{old_str}' not found in file")
|
||||
if occurrences > 1:
|
||||
lines = [i+1 for i, line in enumerate(content.split('\n')) if old_str in line]
|
||||
return self.fail_response(f"Multiple occurrences found in lines {lines}. Please ensure string is unique")
|
||||
|
||||
# Perform replacement
|
||||
new_content = content.replace(old_str, new_str)
|
||||
full_path.write_text(new_content)
|
||||
|
||||
# Update state after file modification
|
||||
await self._update_workspace_state()
|
||||
|
||||
# Show snippet around the edit
|
||||
replacement_line = content.split(old_str)[0].count('\n')
|
||||
start_line = max(0, replacement_line - self.SNIPPET_LINES)
|
||||
end_line = replacement_line + self.SNIPPET_LINES + new_str.count('\n')
|
||||
snippet = '\n'.join(new_content.split('\n')[start_line:end_line + 1])
|
||||
|
||||
return self.success_response(f"Replacement successful. Snippet of changes:\n{snippet}")
|
||||
|
||||
except Exception as e:
|
||||
return self.fail_response(f"Error replacing string: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def test_files_tool():
|
||||
files_tool = FilesTool()
|
||||
test_file_path = "test_file.txt"
|
||||
test_content = "This is a test file."
|
||||
updated_content = "This is an updated test file."
|
||||
|
||||
print(f"Using workspace directory: {files_tool.workspace}")
|
||||
|
||||
# Test create_file
|
||||
create_result = await files_tool.create_file(test_file_path, test_content)
|
||||
print("Create file result:", create_result)
|
||||
|
||||
# Test delete_file
|
||||
delete_result = await files_tool.delete_file(test_file_path)
|
||||
print("Delete file result:", delete_result)
|
||||
|
||||
# Test read_file after delete (should fail)
|
||||
read_deleted_result = await files_tool.read_file(test_file_path)
|
||||
print("Read deleted file result:", read_deleted_result)
|
||||
|
||||
asyncio.run(test_files_tool())
|
|
@ -0,0 +1,73 @@
|
|||
import os
|
||||
import asyncio
|
||||
import subprocess
|
||||
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
||||
from typing import Optional
|
||||
|
||||
class TerminalTool(Tool):
|
||||
"""Terminal command execution tool for workspace operations."""
|
||||
|
||||
def __init__(self, thread_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute_command",
|
||||
"description": "Execute a shell command in the workspace directory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@xml_schema(
|
||||
tag_name="execute-command",
|
||||
mappings=[
|
||||
{"param_name": "command", "node_type": "content", "path": "."}
|
||||
],
|
||||
example='''
|
||||
<execute-command>
|
||||
npm install package-name
|
||||
</execute-command>
|
||||
'''
|
||||
)
|
||||
async def execute_command(self, command: str) -> ToolResult:
|
||||
original_dir = os.getcwd()
|
||||
try:
|
||||
os.chdir(self.workspace)
|
||||
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=self.workspace
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
output = stdout.decode() if stdout else ""
|
||||
error = stderr.decode() if stderr else ""
|
||||
success = process.returncode == 0
|
||||
|
||||
if success:
|
||||
return self.success_response({
|
||||
"output": output,
|
||||
"error": error,
|
||||
"exit_code": process.returncode,
|
||||
"cwd": self.workspace
|
||||
})
|
||||
else:
|
||||
return self.fail_response(f"Command failed with exit code {process.returncode}: {error}")
|
||||
|
||||
except Exception as e:
|
||||
return self.fail_response(f"Error executing command: {str(e)}")
|
||||
finally:
|
||||
os.chdir(original_dir)
|
|
@ -0,0 +1,373 @@
|
|||
:root {
|
||||
--primary-color: #2563eb;
|
||||
--secondary-color: #1e40af;
|
||||
--text-color: #1f2937;
|
||||
--light-text: #6b7280;
|
||||
--background: #ffffff;
|
||||
--section-bg: #f3f4f6;
|
||||
}
|
||||
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
html {
|
||||
scroll-behavior: smooth;
|
||||
}
|
||||
|
||||
.scroll-top {
|
||||
position: fixed;
|
||||
bottom: 2rem;
|
||||
right: 2rem;
|
||||
background: var(--primary-color);
|
||||
color: white;
|
||||
width: 45px;
|
||||
height: 45px;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transition: all 0.3s ease;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
.scroll-top.visible {
|
||||
opacity: 1;
|
||||
visibility: visible;
|
||||
}
|
||||
|
||||
.scroll-top:hover {
|
||||
transform: translateY(-3px);
|
||||
background: var(--secondary-color);
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 0 2rem;
|
||||
}
|
||||
|
||||
.header {
|
||||
position: fixed;
|
||||
width: 100%;
|
||||
background: var(--background);
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||
z-index: 1000;
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.header.scrolled {
|
||||
transform: translateY(-100%);
|
||||
}
|
||||
|
||||
.header:hover {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.nav {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 1rem 2rem;
|
||||
}
|
||||
|
||||
.logo {
|
||||
font-size: 1.5rem;
|
||||
font-weight: bold;
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.nav-links {
|
||||
display: flex;
|
||||
gap: 2rem;
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
.nav-links a {
|
||||
text-decoration: none;
|
||||
color: var(--text-color);
|
||||
font-weight: 500;
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.nav-links a:hover {
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.mobile-nav-toggle {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.hero {
|
||||
padding: 8rem 0 4rem;
|
||||
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
|
||||
background-size: 200% 200%;
|
||||
color: white;
|
||||
text-align: center;
|
||||
animation: gradientMove 10s ease infinite;
|
||||
}
|
||||
|
||||
@keyframes gradientMove {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
|
||||
.hero h1 {
|
||||
font-size: 3rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.hero p {
|
||||
font-size: 1.25rem;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.cta-button {
|
||||
padding: 1rem 2rem;
|
||||
font-size: 1.1rem;
|
||||
background: white;
|
||||
color: var(--primary-color);
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.cta-button:hover {
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.features {
|
||||
padding: 4rem 0;
|
||||
background: var(--section-bg);
|
||||
}
|
||||
|
||||
.features h2 {
|
||||
text-align: center;
|
||||
margin-bottom: 3rem;
|
||||
}
|
||||
|
||||
.feature-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
|
||||
gap: 2rem;
|
||||
}
|
||||
|
||||
.feature-card {
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 10px;
|
||||
text-align: center;
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
||||
transition: all 0.3s ease;
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
animation: fadeInUp 0.6s ease forwards;
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.feature-card:hover {
|
||||
transform: translateY(-5px);
|
||||
}
|
||||
|
||||
.feature-icon {
|
||||
font-size: 2.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.contact {
|
||||
padding: 4rem 0;
|
||||
}
|
||||
|
||||
.contact h2 {
|
||||
text-align: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.contact-form {
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.contact-form input,
|
||||
.contact-form textarea {
|
||||
padding: 0.8rem;
|
||||
border: 2px solid #ddd;
|
||||
border-radius: 5px;
|
||||
font-size: 1rem;
|
||||
transition: all 0.3s ease;
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.contact-form input:focus,
|
||||
.contact-form textarea:focus {
|
||||
border-color: var(--primary-color);
|
||||
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
||||
}
|
||||
|
||||
.contact-form input:hover,
|
||||
.contact-form textarea:hover {
|
||||
border-color: var(--primary-color);
|
||||
}
|
||||
|
||||
.contact-form textarea {
|
||||
height: 150px;
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
.submit-button {
|
||||
padding: 1rem;
|
||||
background: var(--primary-color);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.3s ease;
|
||||
}
|
||||
|
||||
.submit-button:hover {
|
||||
background: var(--secondary-color);
|
||||
}
|
||||
|
||||
.testimonials {
|
||||
padding: 4rem 0;
|
||||
background: var(--section-bg);
|
||||
}
|
||||
|
||||
.testimonials h2 {
|
||||
text-align: center;
|
||||
margin-bottom: 3rem;
|
||||
}
|
||||
|
||||
.testimonial-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
|
||||
gap: 2rem;
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 0 2rem;
|
||||
}
|
||||
|
||||
.testimonial-card {
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.testimonial-text {
|
||||
font-style: italic;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.testimonial-author {
|
||||
font-weight: bold;
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.footer {
|
||||
background: var(--text-color);
|
||||
color: white;
|
||||
padding: 2rem 0;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.social-links {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
gap: 1.5rem;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
|
||||
.social-links a {
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
font-size: 1.5rem;
|
||||
transition: color 0.3s ease;
|
||||
}
|
||||
|
||||
.social-links a:hover {
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.nav-links {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 70px;
|
||||
left: 0;
|
||||
right: 0;
|
||||
background: var(--background);
|
||||
flex-direction: column;
|
||||
padding: 2rem;
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||
text-align: center;
|
||||
transform: translateY(-100%);
|
||||
opacity: 0;
|
||||
transition: transform 0.3s ease, opacity 0.3s ease;
|
||||
}
|
||||
|
||||
.nav-links.active {
|
||||
display: flex;
|
||||
transform: translateY(0);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.mobile-nav-toggle {
|
||||
display: block;
|
||||
background: none;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.mobile-nav-toggle span {
|
||||
display: block;
|
||||
width: 25px;
|
||||
height: 3px;
|
||||
background: var(--text-color);
|
||||
margin: 5px 0;
|
||||
transition: 0.3s;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.mobile-nav-toggle.active span:nth-child(1) {
|
||||
transform: rotate(45deg) translate(5px, 5px);
|
||||
}
|
||||
|
||||
.mobile-nav-toggle.active span:nth-child(2) {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
.mobile-nav-toggle.active span:nth-child(3) {
|
||||
transform: rotate(-45deg) translate(7px, -6px);
|
||||
}
|
||||
|
||||
.hero h1 {
|
||||
font-size: 2.5rem;
|
||||
}
|
||||
|
||||
.feature-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Modern Landing Page</title>
|
||||
<link rel="stylesheet" href="css/styles.css">
|
||||
<style>
|
||||
.loading {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: white;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
z-index: 9999;
|
||||
}
|
||||
.loading::after {
|
||||
content: '';
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: 4px solid #ddd;
|
||||
border-top-color: var(--primary-color);
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
@keyframes spin {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
.loaded {
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transition: all 0.5s ease;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="loading"></div>
|
||||
<header class="header">
|
||||
<nav class="nav">
|
||||
<div class="logo">Brand</div>
|
||||
<ul class="nav-links">
|
||||
<li><a href="#home">Home</a></li>
|
||||
<li><a href="#features">Features</a></li>
|
||||
<li><a href="#about">About</a></li>
|
||||
<li><a href="#contact">Contact</a></li>
|
||||
</ul>
|
||||
<button class="mobile-nav-toggle" aria-label="Toggle menu">
|
||||
<span></span>
|
||||
<span></span>
|
||||
<span></span>
|
||||
</button>
|
||||
</nav>
|
||||
</header>
|
||||
|
||||
<main>
|
||||
<section id="hero" class="hero">
|
||||
<div class="container">
|
||||
<h1>Welcome to the Future</h1>
|
||||
<p>Experience innovation like never before</p>
|
||||
<button class="cta-button">Get Started</button>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section id="features" class="features">
|
||||
<div class="container">
|
||||
<h2>Our Features</h2>
|
||||
<div class="feature-grid">
|
||||
<div class="feature-card">
|
||||
<span class="feature-icon">🚀</span>
|
||||
<h3>Fast Performance</h3>
|
||||
<p>Lightning-quick loading times</p>
|
||||
</div>
|
||||
<div class="feature-card">
|
||||
<span class="feature-icon">🎨</span>
|
||||
<h3>Beautiful Design</h3>
|
||||
<p>Stunning visual experience</p>
|
||||
</div>
|
||||
<div class="feature-card">
|
||||
<span class="feature-icon">📱</span>
|
||||
<h3>Responsive</h3>
|
||||
<p>Works on all devices</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section id="contact" class="contact">
|
||||
<div class="container">
|
||||
<h2>Get in Touch</h2>
|
||||
<form class="contact-form">
|
||||
<input type="text" name="name" placeholder="Name" required>
|
||||
<input type="email" name="email" placeholder="Email" required>
|
||||
<textarea name="message" placeholder="Message" required></textarea>
|
||||
<button type="submit" class="submit-button">Send Message</button>
|
||||
</form>
|
||||
</div>
|
||||
</section>
|
||||
</main>
|
||||
|
||||
<footer class="footer">
|
||||
<div class="container">
|
||||
<p>© 2024 Brand. All rights reserved.</p>
|
||||
</div>
|
||||
</footer>
|
||||
|
||||
<button class="scroll-top" aria-label="Scroll to top">↑</button>
|
||||
<script src="js/main.js"></script>
|
||||
</body>
|
||||
</html>
|
|
@ -0,0 +1,85 @@
|
|||
const mobileNavToggle = document.querySelector('.mobile-nav-toggle');
|
||||
const navLinks = document.querySelector('.nav-links');
|
||||
|
||||
mobileNavToggle.addEventListener('click', () => {
|
||||
navLinks.classList.toggle('active');
|
||||
mobileNavToggle.classList.toggle('active');
|
||||
});
|
||||
|
||||
document.addEventListener('click', (e) => {
|
||||
if (!e.target.closest('.nav') && navLinks.classList.contains('active')) {
|
||||
navLinks.classList.remove('active');
|
||||
mobileNavToggle.classList.remove('active');
|
||||
}
|
||||
});
|
||||
|
||||
document.querySelectorAll('a[href^="#"]').forEach(anchor => {
|
||||
anchor.addEventListener('click', function (e) {
|
||||
e.preventDefault();
|
||||
document.querySelector(this.getAttribute('href')).scrollIntoView({
|
||||
behavior: 'smooth'
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
const form = document.querySelector('.contact-form');
|
||||
form.addEventListener('submit', (e) => {
|
||||
e.preventDefault();
|
||||
const formData = new FormData(form);
|
||||
const data = Object.fromEntries(formData);
|
||||
console.log('Form submitted:', data);
|
||||
form.reset();
|
||||
});
|
||||
|
||||
let lastScroll = 0;
|
||||
window.addEventListener('load', () => {
|
||||
setTimeout(() => {
|
||||
document.querySelector('.loading').classList.add('loaded');
|
||||
}, 500);
|
||||
});
|
||||
|
||||
const scrollTopBtn = document.querySelector('.scroll-top');
|
||||
scrollTopBtn.addEventListener('click', () => {
|
||||
window.scrollTo({
|
||||
top: 0,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
});
|
||||
|
||||
const observeElements = () => {
|
||||
const observer = new IntersectionObserver((entries) => {
|
||||
entries.forEach(entry => {
|
||||
if (entry.isIntersecting) {
|
||||
entry.target.style.opacity = '1';
|
||||
entry.target.style.transform = 'translateY(0)';
|
||||
}
|
||||
});
|
||||
}, { threshold: 0.1 });
|
||||
|
||||
document.querySelectorAll('.feature-card').forEach(card => {
|
||||
observer.observe(card);
|
||||
card.style.opacity = '0';
|
||||
card.style.transform = 'translateY(20px)';
|
||||
card.style.transition = 'all 0.6s ease';
|
||||
});
|
||||
};
|
||||
|
||||
document.addEventListener('DOMContentLoaded', observeElements);
|
||||
|
||||
window.addEventListener('scroll', () => {
|
||||
scrollTopBtn.classList.toggle('visible', window.scrollY > 500);
|
||||
const header = document.querySelector('.header');
|
||||
const currentScroll = window.pageYOffset;
|
||||
|
||||
if (currentScroll <= 0) {
|
||||
header.classList.remove('scrolled');
|
||||
return;
|
||||
}
|
||||
|
||||
if (currentScroll > lastScroll && !header.classList.contains('scrolled')) {
|
||||
header.classList.add('scrolled');
|
||||
} else if (currentScroll < lastScroll && header.classList.contains('scrolled')) {
|
||||
header.classList.remove('scrolled');
|
||||
}
|
||||
lastScroll = currentScroll;
|
||||
});
|
Loading…
Reference in New Issue