suna/agentpress/thread_manager.py

366 lines
16 KiB
Python

import json
import logging
import asyncio
from typing import List, Dict, Any, Optional, Callable, Type
from sqlalchemy import select
from agentpress.db import Database, Thread
from agentpress.llm import make_llm_api_call
from datetime import datetime, UTC
from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
import uuid
class ThreadManager:
def __init__(self, db: Optional[Database] = None):
self.db = db if db is not None else Database()
self.tool_registry = ToolRegistry()
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None):
"""
Add a tool to the ThreadManager.
If function_names is provided, only register those specific functions.
If function_names is None, register all functions from the tool.
"""
self.tool_registry.register_tool(tool_class, function_names)
async def create_thread(self) -> int:
async with self.db.get_async_session() as session:
new_thread = Thread(
messages=json.dumps([])
)
session.add(new_thread)
await session.commit()
await session.refresh(new_thread)
return new_thread.thread_id
async def add_message(self, thread_id: int, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
logging.info(f"Adding message to thread {thread_id} with images: {images}")
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if not thread:
raise ValueError(f"Thread with id {thread_id} not found")
try:
messages = json.loads(thread.messages)
if message_data['role'] == 'user':
last_assistant_index = next((i for i in reversed(range(len(messages))) if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
if last_assistant_index is not None:
tool_call_count = len(messages[last_assistant_index]['tool_calls'])
tool_response_count = sum(1 for msg in messages[last_assistant_index+1:] if msg['role'] == 'tool')
if tool_call_count != tool_response_count:
await self.cleanup_incomplete_tool_calls(thread_id)
for key, value in message_data.items():
if isinstance(value, ToolResult):
message_data[key] = str(value)
if images:
if isinstance(message_data['content'], str):
message_data['content'] = [{"type": "text", "text": message_data['content']}]
elif not isinstance(message_data['content'], list):
message_data['content'] = []
for image in images:
image_content = {
"type": "image_url",
"image_url": {
"url": f"data:{image['content_type']};base64,{image['base64']}",
"detail": "high"
}
}
message_data['content'].append(image_content)
messages.append(message_data)
thread.messages = json.dumps(messages)
await session.commit()
logging.info(f"Message added to thread {thread_id}: {message_data}")
except Exception as e:
await session.rollback()
logging.error(f"Failed to add message to thread {thread_id}: {e}")
raise e
async def list_messages(self, thread_id: int, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]:
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if not thread:
return []
messages = json.loads(thread.messages)
if only_latest_assistant:
for msg in reversed(messages):
if msg.get('role') == 'assistant':
return [msg]
return []
filtered_messages = messages
if hide_tool_msgs:
filtered_messages = [
{k: v for k, v in msg.items() if k != 'tool_calls'}
for msg in filtered_messages
if msg.get('role') != 'tool'
]
if regular_list:
filtered_messages = [
msg for msg in filtered_messages
if msg.get('role') in ['system', 'assistant', 'tool', 'user']
]
return filtered_messages
async def cleanup_incomplete_tool_calls(self, thread_id: int):
messages = await self.list_messages(thread_id)
last_assistant_message = next((m for m in reversed(messages) if m['role'] == 'assistant' and 'tool_calls' in m), None)
if last_assistant_message:
tool_calls = last_assistant_message.get('tool_calls', [])
tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:] if m['role'] == 'tool']
if len(tool_calls) != len(tool_responses):
failed_tool_results = []
for tool_call in tool_calls[len(tool_responses):]:
failed_tool_result = {
"role": "tool",
"tool_call_id": tool_call['id'],
"name": tool_call['function']['name'],
"content": "ToolResult(success=False, output='Execution interrupted. Session was stopped.')"
}
failed_tool_results.append(failed_tool_result)
assistant_index = messages.index(last_assistant_message)
messages[assistant_index+1:assistant_index+1] = failed_tool_results
async with self.db.get_async_session() as session:
thread = await session.get(Thread, thread_id)
if thread:
thread.messages = json.dumps(messages)
await session.commit()
return True
return False
async def run_thread(self, thread_id: int, system_message: Dict[str, Any], model_name: str, temperature: float = 0, max_tokens: Optional[int] = None, tool_choice: str = "auto", additional_message: Optional[Dict[str, Any]] = None, execute_tools_async: bool = True, execute_model_tool_calls: bool = True, use_tools: bool = True) -> Dict[str, Any]:
messages = await self.list_messages(thread_id)
prepared_messages = [system_message] + messages
if additional_message:
prepared_messages.append(additional_message)
tools = self.tool_registry.get_all_tool_schemas() if use_tools else None
try:
llm_response = await make_llm_api_call(
prepared_messages,
model_name,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=tool_choice if use_tools else None,
stream=False
)
except Exception as e:
logging.error(f"Error in API call: {str(e)}")
return {
"status": "error",
"message": str(e),
"run_thread_params": {
"thread_id": thread_id,
"system_message": system_message,
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"tool_choice": tool_choice,
"additional_message": additional_message,
"execute_tools_async": execute_tools_async,
"execute_model_tool_calls": execute_model_tool_calls,
"use_tools": use_tools
}
}
if use_tools and execute_model_tool_calls:
await self.handle_response_with_tools(thread_id, llm_response, execute_tools_async)
else:
await self.handle_response_without_tools(thread_id, llm_response)
return {
"llm_response": llm_response,
"run_thread_params": {
"thread_id": thread_id,
"system_message": system_message,
"model_name": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"tool_choice": tool_choice,
"additional_message": additional_message,
"execute_tools_async": execute_tools_async,
"execute_model_tool_calls": execute_model_tool_calls,
"use_tools": use_tools
}
}
async def handle_response_without_tools(self, thread_id: int, response: Any):
response_content = response.choices[0].message['content']
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
async def handle_response_with_tools(self, thread_id: int, response: Any, execute_tools_async: bool):
try:
response_message = response.choices[0].message
tool_calls = response_message.get('tool_calls', [])
assistant_message = self.create_assistant_message_with_tools(response_message)
await self.add_message(thread_id, assistant_message)
available_functions = self.get_available_functions()
if tool_calls:
if execute_tools_async:
tool_results = await self.execute_tools_async(tool_calls, available_functions, thread_id)
else:
tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id)
for result in tool_results:
await self.add_message(thread_id, result)
except AttributeError as e:
logging.error(f"AttributeError: {e}")
response_content = response.choices[0].message['content']
await self.add_message(thread_id, {"role": "assistant", "content": response_content or ""})
def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
message = {
"role": "assistant",
"content": response_message.get('content') or "",
}
tool_calls = response_message.get('tool_calls')
if tool_calls:
message["tool_calls"] = [
{
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
} for tool_call in tool_calls
]
return message
def get_available_functions(self) -> Dict[str, Callable]:
available_functions = {}
for tool_name, tool_info in self.tool_registry.get_all_tools().items():
tool_instance = tool_info['instance']
for func_name, func in tool_instance.__class__.__dict__.items():
if callable(func) and not func_name.startswith("__"):
available_functions[func_name] = getattr(tool_instance, func_name)
return available_functions
async def execute_tools_async(self, tool_calls, available_functions, thread_id):
async def execute_single_tool(tool_call):
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
tool_call_id = tool_call.id
function_to_call = available_functions.get(function_name)
if function_to_call:
return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
else:
logging.warning(f"Function {function_name} not found in available functions")
return None
tool_results = await asyncio.gather(*[execute_single_tool(tool_call) for tool_call in tool_calls])
return [result for result in tool_results if result]
async def execute_tools_sync(self, tool_calls, available_functions, thread_id):
tool_results = []
for tool_call in tool_calls:
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
tool_call_id = tool_call.id
function_to_call = available_functions.get(function_name)
if function_to_call:
result = await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
if result:
tool_results.append(result)
else:
logging.warning(f"Function {function_name} not found in available functions")
return tool_results
async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id):
try:
function_response = await function_to_call(**function_args)
except Exception as e:
error_message = f"Error in {function_name}: {str(e)}"
function_response = ToolResult(success=False, output=error_message)
return {
"role": "tool",
"tool_call_id": tool_call_id,
"name": function_name,
"content": str(function_response),
}
async def get_thread(self, thread_id: int) -> Optional[Thread]:
async with self.db.get_async_session() as session:
return await session.get(Thread, thread_id)
if __name__ == "__main__":
import asyncio
from tools.files_tool import FilesTool
async def main():
manager = ThreadManager()
manager.add_tool(FilesTool, ['create_file', 'read_file'])
thread_id = await manager.create_thread()
await manager.add_message(thread_id, {"role": "user", "content": "Please create a file with a random name with the content 'Hello, world!'"})
system_message = {"role": "system", "content": "You are a helpful assistant that can create, read, update, and delete files."}
model_name = "gpt-4o"
# Test with tools
response_with_tools = await manager.run_thread(
thread_id=thread_id,
system_message=system_message,
model_name=model_name,
temperature=0.7,
max_tokens=150,
tool_choice="auto",
additional_message=None,
execute_tools_async=True,
execute_model_tool_calls=True,
use_tools=True
)
print("Response with tools:", response_with_tools)
# Test without tools
response_without_tools = await manager.run_thread(
thread_id=thread_id,
system_message=system_message,
model_name=model_name,
temperature=0.7,
max_tokens=150,
additional_message={"role": "user", "content": "What's the capital of France?"},
use_tools=False
)
print("Response without tools:", response_without_tools)
# List messages in the thread
messages = await manager.list_messages(thread_id)
print("\nMessages in the thread:")
for msg in messages:
print(f"{msg['role'].capitalize()}: {msg['content']}")
# Run the async main function
asyncio.run(main())