# suna/backend/agentpress/thread_manager.py

"""
Conversation thread management system for AgentPress.

This module provides comprehensive conversation management, including:
- Thread creation and persistence
- Message handling with support for text and images
- Tool registration and execution
- LLM interaction with streaming support
- Error handling and cleanup
"""

import json
import logging
import asyncio
import uuid
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator
from services.llm import make_llm_api_call
from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
from agentpress.processor.llm_response_processor import LLMResponseProcessor
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from services.supabase import DBConnection
from agentpress.processor.xml.xml_tool_parser import XMLToolParser
from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor
from agentpress.processor.xml.xml_results_adder import XMLResultsAdder
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder


class ThreadManager:
"""Manages conversation threads with LLM models and tool execution.

    Provides comprehensive conversation management, handling message threading,
tool registration, and LLM interactions with support for both standard and
XML-based tool execution patterns.
"""

    def __init__(self):
"""Initialize ThreadManager."""
self.db = DBConnection()
self.tool_registry = ToolRegistry()

    def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
"""Add a tool to the ThreadManager."""
self.tool_registry.register_tool(tool_class, function_names, **kwargs)

    async def create_thread(self) -> str:
"""Create a new conversation thread.

        Returns:
str: The ID of the newly created thread
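
        Example (illustrative; assumes a configured Supabase connection):
            thread_id = await manager.create_thread()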
"""
thread_id = str(uuid.uuid4())
client = await self.db.client
thread_data = {
'thread_id': thread_id,
'messages': json.dumps([])
}
await client.table('threads').insert(thread_data).execute()
return thread_id

    async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
"""Add a message to an existing thread."""
logging.info(f"Adding message to thread {thread_id} with images: {images}")
try:
# Handle cleanup of incomplete tool calls
if message_data['role'] == 'user':
messages = await self.get_messages(thread_id)
last_assistant_index = next((i for i in reversed(range(len(messages)))
if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
if last_assistant_index is not None:
tool_call_count = len(messages[last_assistant_index]['tool_calls'])
tool_response_count = sum(1 for msg in messages[last_assistant_index+1:]
if msg['role'] == 'tool')
if tool_call_count != tool_response_count:
await self.cleanup_incomplete_tool_calls(thread_id)
# Convert ToolResult instances to strings
for key, value in message_data.items():
if isinstance(value, ToolResult):
message_data[key] = str(value)
# Handle image attachments
if images:
if isinstance(message_data['content'], str):
message_data['content'] = [{"type": "text", "text": message_data['content']}]
elif not isinstance(message_data['content'], list):
message_data['content'] = []
for image in images:
image_content = {
"type": "image_url",
"image_url": {
"url": f"data:{image['content_type']};base64,{image['base64']}",
"detail": "high"
}
}
message_data['content'].append(image_content)
# Get current messages
client = await self.db.client
thread = await client.table('threads').select('*').eq('thread_id', thread_id).single().execute()
if not thread.data:
raise ValueError(f"Thread {thread_id} not found")
messages = json.loads(thread.data['messages'])
messages.append(message_data)
# Update thread
await client.table('threads').update({
'messages': json.dumps(messages)
}).eq('thread_id', thread_id).execute()
logging.info(f"Message added to thread {thread_id}: {message_data}")
except Exception as e:
logging.error(f"Failed to add message to thread {thread_id}: {e}")
            raise

    async def get_messages(
self,
thread_id: str,
hide_tool_msgs: bool = False,
only_latest_assistant: bool = False,
regular_list: bool = True
) -> List[Dict[str, Any]]:
"""Retrieve messages from a thread with optional filtering."""
client = await self.db.client
thread = await client.table('threads').select('*').eq('thread_id', thread_id).single().execute()
if not thread.data:
return []
messages = json.loads(thread.data['messages'])
if only_latest_assistant:
for msg in reversed(messages):
if msg.get('role') == 'assistant':
return [msg]
return []
if hide_tool_msgs:
messages = [
{k: v for k, v in msg.items() if k != 'tool_calls'}
for msg in messages
if msg.get('role') != 'tool'
]
if regular_list:
messages = [
msg for msg in messages
if msg.get('role') in ['system', 'assistant', 'tool', 'user']
]
return messages

    async def _update_message(self, thread_id: str, message: Dict[str, Any]):
"""Update an existing message in the thread."""
client = await self.db.client
thread = await client.table('threads').select('*').eq('thread_id', thread_id).single().execute()
if not thread.data:
return
messages = json.loads(thread.data['messages'])
# Find and update the last assistant message
for i in reversed(range(len(messages))):
if messages[i].get('role') == 'assistant':
messages[i] = message
break
await client.table('threads').update({
'messages': json.dumps(messages)
}).eq('thread_id', thread_id).execute()

    async def cleanup_incomplete_tool_calls(self, thread_id: str):
"""Clean up incomplete tool calls in a thread."""
messages = await self.get_messages(thread_id)
last_assistant_message = next((m for m in reversed(messages)
if m['role'] == 'assistant' and 'tool_calls' in m), None)
if last_assistant_message:
tool_calls = last_assistant_message.get('tool_calls', [])
tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:]
if m['role'] == 'tool']
if len(tool_calls) != len(tool_responses):
failed_tool_results = []
for tool_call in tool_calls[len(tool_responses):]:
failed_tool_result = {
"role": "tool",
"tool_call_id": tool_call['id'],
"name": tool_call['function']['name'],
"content": "ToolResult(success=False, output='Execution interrupted. Session was stopped.')"
}
failed_tool_results.append(failed_tool_result)
assistant_index = messages.index(last_assistant_message)
messages[assistant_index+1:assistant_index+1] = failed_tool_results
client = await self.db.client
await client.table('threads').update({
'messages': json.dumps(messages)
}).eq('thread_id', thread_id).execute()
return True
return False

    async def run_thread(
self,
thread_id: str,
system_message: Dict[str, Any],
model_name: str,
temperature: float = 0,
max_tokens: Optional[int] = None,
tool_choice: str = "auto",
temporary_message: Optional[Dict[str, Any]] = None,
native_tool_calling: bool = False,
xml_tool_calling: bool = False,
execute_tools: bool = True,
stream: bool = False,
execute_tools_on_stream: bool = False,
parallel_tool_execution: bool = False,
tool_parser: Optional[ToolParserBase] = None,
tool_executor: Optional[ToolExecutorBase] = None,
results_adder: Optional[ResultsAdderBase] = None
) -> Union[Dict[str, Any], AsyncGenerator]:
"""Run a conversation thread with specified parameters.
Args:
thread_id: ID of the thread to run
system_message: System message for the conversation
model_name: Name of the LLM model to use
temperature: Model temperature (0-1)
max_tokens: Maximum tokens in response
tool_choice: Tool selection strategy ("auto" or "none")
temporary_message: Optional message to include temporarily
native_tool_calling: Whether to use native LLM function calling
xml_tool_calling: Whether to use XML-based tool calling
execute_tools: Whether to execute tool calls
stream: Whether to stream the response
execute_tools_on_stream: Whether to execute tools during streaming
parallel_tool_execution: Whether to execute tools in parallel
tool_parser: Custom tool parser implementation
tool_executor: Custom tool executor implementation
results_adder: Custom results adder implementation
Returns:
Union[Dict[str, Any], AsyncGenerator]: Response or stream
Raises:
ValueError: If incompatible tool calling options are specified
Exception: For other execution failures
Notes:
- Cannot use both native and XML tool calling simultaneously
- Streaming responses include both content and tool results
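
        Example (illustrative; the model name and prompt are placeholders):
            response = await manager.run_thread(
                thread_id=thread_id,
                system_message={"role": "system", "content": "You are a helpful assistant."},
                model_name="gpt-4o",
                stream=True,
                xml_tool_calling=True,
                execute_tools_on_stream=True,
            )
            async for chunk in response:
                ...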
"""
# Validate tool calling configuration
if native_tool_calling and xml_tool_calling:
raise ValueError("Cannot use both native LLM tool calling and XML tool calling simultaneously")
# Initialize tool components if any tool calling is enabled
if native_tool_calling or xml_tool_calling:
if tool_parser is None:
tool_parser = XMLToolParser(tool_registry=self.tool_registry) if xml_tool_calling else StandardToolParser()
if tool_executor is None:
tool_executor = XMLToolExecutor(parallel=parallel_tool_execution, tool_registry=self.tool_registry) if xml_tool_calling else StandardToolExecutor(parallel=parallel_tool_execution)
if results_adder is None:
results_adder = XMLResultsAdder(self) if xml_tool_calling else StandardResultsAdder(self)
try:
messages = await self.get_messages(thread_id)
prepared_messages = [system_message] + messages
if temporary_message:
prepared_messages.append(temporary_message)
openapi_tool_schemas = None
if native_tool_calling:
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
available_functions = self.tool_registry.get_available_functions()
elif xml_tool_calling:
available_functions = self.tool_registry.get_available_functions()
else:
available_functions = {}
response_processor = LLMResponseProcessor(
thread_id=thread_id,
available_functions=available_functions,
add_message_callback=self.add_message,
update_message_callback=self._update_message,
get_messages_callback=self.get_messages,
parallel_tool_execution=parallel_tool_execution,
tool_parser=tool_parser,
tool_executor=tool_executor,
results_adder=results_adder
)
llm_response = await self._run_thread_completion(
messages=prepared_messages,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
tools=openapi_tool_schemas,
tool_choice=tool_choice if native_tool_calling else None,
stream=stream
)
if stream:
return response_processor.process_stream(
response_stream=llm_response,
execute_tools=execute_tools,
execute_tools_on_stream=execute_tools_on_stream
)
await response_processor.process_response(
response=llm_response,
execute_tools=execute_tools
)
return llm_response
except Exception as e:
logging.error(f"Error in run_thread: {str(e)}")
return {
"status": "error",
"message": str(e)
}

    async def _run_thread_completion(
self,
messages: List[Dict[str, Any]],
model_name: str,
temperature: float,
max_tokens: Optional[int],
tools: Optional[List[Dict[str, Any]]],
tool_choice: Optional[str],
stream: bool
) -> Union[Any, AsyncGenerator]:
"""Get completion from LLM API."""
return await make_llm_api_call(
messages,
model_name,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=tool_choice,
stream=stream
)
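

# Minimal end-to-end sketch (illustrative, not part of the module's public
# API; assumes a configured Supabase backend and valid LLM credentials).
if __name__ == "__main__":
    async def _demo():
        manager = ThreadManager()
        thread_id = await manager.create_thread()
        await manager.add_message(thread_id, {"role": "user", "content": "Hello!"})
        result = await manager.run_thread(
            thread_id=thread_id,
            system_message={"role": "system", "content": "You are a helpful assistant."},
            model_name="gpt-4o",  # placeholder model name
        )
        print(result)

    asyncio.run(_demo())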