"""
Conversation thread management system for AgentPress.

This module provides comprehensive conversation management, including:
- Thread creation and persistence
- Message handling with support for text and images
- Tool registration and execution
- LLM interaction with streaming support
- Error handling and cleanup
"""

import json
import uuid
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Tuple, Callable, Literal

from services.llm import make_llm_api_call
from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
from agentpress.response_processor import (
    ResponseProcessor,
    ProcessorConfig
)
from services.supabase import DBConnection
from utils.logger import logger

# Type alias for tool choice
ToolChoice = Literal["auto", "required", "none"]


class ThreadManager:
    """Manages conversation threads with LLM models and tool execution.

    Provides comprehensive conversation management, handling message threading,
    tool registration, and LLM interactions with support for both standard and
    XML-based tool execution patterns.
    """

    def __init__(self):
        """Initialize ThreadManager."""
        self.db = DBConnection()
        self.tool_registry = ToolRegistry()
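        # Give the processor a way to persist messages back to the thread via add_message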
        self.response_processor = ResponseProcessor(
            tool_registry=self.tool_registry,
            add_message_callback=self.add_message
        )

    def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
        """Add a tool to the ThreadManager."""
        self.tool_registry.register_tool(tool_class, function_names, **kwargs)

    async def add_message(
        self,
        thread_id: str,
        type: str,
        content: Union[Dict[str, Any], List[Any], str],
        is_llm_message: bool = False,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Add a message to the thread in the database.

        Args:
            thread_id: The ID of the thread to add the message to.
            type: The type of the message (e.g., 'text', 'image_url', 'tool_call').
            content: The content of the message. Can be a dictionary, list, or string.
                It will be stored as JSONB in the database.
            is_llm_message: Flag indicating if the message originated from the LLM.
                Defaults to False (user message).
            metadata: Optional dictionary for additional message metadata.
                Defaults to None, stored as an empty JSONB object if None.
        """
        logger.debug(f"Adding message of type '{type}' to thread {thread_id}")
        client = await self.db.client

        # Prepare data for insertion
        data_to_insert = {
            'thread_id': thread_id,
            'type': type,
            'content': json.dumps(content) if isinstance(content, (dict, list)) else content,
            'is_llm_message': is_llm_message,
            'metadata': json.dumps(metadata or {}),  # Ensure metadata is always a JSON object
        }

        try:
            result = await client.table('messages').insert(data_to_insert).execute()
            logger.info(f"Successfully added message to thread {thread_id}")
        except Exception as e:
            logger.error(f"Failed to add message to thread {thread_id}: {str(e)}", exc_info=True)
            raise

    async def get_messages(self, thread_id: str) -> List[Dict[str, Any]]:
        """Get all messages for a thread.

        Args:
            thread_id: The ID of the thread to get messages for.

        Returns:
            List of message objects. An empty list is returned when the thread
            has no messages or the query fails.
        """
        logger.debug(f"Getting messages for thread {thread_id}")
        client = await self.db.client

        try:
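            # The get_llm_formatted_messages RPC is expected to return rows already
            # shaped as chat messages (role/content) ready to pass to the LLM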
            result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()

            # Parse the returned data, which might be stringified JSON
            if not result.data:
                return []

            # Return properly parsed JSON objects
            messages = []
            for item in result.data:
                if isinstance(item, str):
                    try:
                        parsed_item = json.loads(item)
                        messages.append(parsed_item)
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse message: {item}")
                else:
                    messages.append(item)

            return messages

        except Exception as e:
            logger.error(f"Failed to get messages for thread {thread_id}: {str(e)}", exc_info=True)
            return []

    async def run_thread(
        self,
        thread_id: str,
        system_prompt: Dict[str, Any],
        stream: bool = True,
        temporary_message: Optional[Dict[str, Any]] = None,
        llm_model: str = "gpt-4o",
        llm_temperature: float = 0,
        llm_max_tokens: Optional[int] = None,
        processor_config: Optional[ProcessorConfig] = None,
        tool_choice: ToolChoice = "auto",
    ) -> Union[Dict[str, Any], AsyncGenerator]:
        """Run a conversation thread with LLM integration and tool execution.

        Args:
            thread_id: The ID of the thread to run
            system_prompt: System message to set the assistant's behavior
            stream: Use streaming API for the LLM response
            temporary_message: Optional temporary user message for this run only
            llm_model: The name of the LLM model to use
            llm_temperature: Temperature parameter for response randomness (0-1)
            llm_max_tokens: Maximum tokens in the LLM response
            processor_config: Configuration for the response processor
            tool_choice: Tool choice preference ("auto", "required", "none")

        Returns:
            An async generator yielding response chunks when streaming, the processed
            response when not streaming, or an error dict if the run fails.
        """
        logger.info(f"Starting thread execution for thread {thread_id}")
        logger.debug(f"Parameters: model={llm_model}, temperature={llm_temperature}, max_tokens={llm_max_tokens}")

        try:
            # 1. Get messages from thread for LLM call
            messages = await self.get_messages(thread_id)

            # 2. Prepare messages for LLM call + add temporary message if it exists
            prepared_messages = [system_prompt]

            # Find the last user message index
            last_user_index = -1
            for i, msg in enumerate(messages):
                if msg.get('role') == 'user':
                    last_user_index = i

            # Insert temporary message before the last user message if it exists
            if temporary_message and last_user_index >= 0:
                prepared_messages.extend(messages[:last_user_index])
                prepared_messages.append(temporary_message)
                prepared_messages.extend(messages[last_user_index:])
                logger.debug("Added temporary message before the last user message")
            else:
                # If no user message or no temporary message, just add all messages
                prepared_messages.extend(messages)
                if temporary_message:
                    prepared_messages.append(temporary_message)
                    logger.debug("Added temporary message to the end of prepared messages")

            # 3. Create or use processor config
            if processor_config is None:
                processor_config = ProcessorConfig()

            logger.debug(f"Processor config: XML={processor_config.xml_tool_calling}, Native={processor_config.native_tool_calling}, "
                         f"Execute tools={processor_config.execute_tools}, Strategy={processor_config.tool_execution_strategy}")

            # Native tool calling is not supported in this version; return an error if it is requested
            if processor_config.native_tool_calling:
                error_message = "Native tool calling is not supported in this version"
                logger.error(error_message)
                return {
                    "status": "error",
                    "message": error_message
                }

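            # NOTE: because native tool calling was rejected above, the native-tool-calling
            # branch in step 4 below cannot run in this version; it is kept for when native
            # tool calling is supported again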
            # 4. Prepare tools for LLM call
            openapi_tool_schemas = None
            if processor_config.native_tool_calling:
                openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
                logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")

            # 5. Track this agent run in the database
            run_id = str(uuid.uuid4())
            client = await self.db.client
            run_data = {
                'id': run_id,
                'thread_id': thread_id,
                'status': 'running',
                'started_at': 'now()',
            }
            await client.table('agent_runs').insert(run_data).execute()
            logger.debug(f"Created agent run record with ID: {run_id}")

            # 6. Make LLM API call
            logger.info("Making LLM API call")
            try:
                llm_response = await make_llm_api_call(
                    prepared_messages,
                    llm_model,
                    temperature=llm_temperature,
                    max_tokens=llm_max_tokens,
                    tools=openapi_tool_schemas,
                    tool_choice=tool_choice if processor_config.native_tool_calling else None,
                    stream=stream
                )
                logger.debug("Successfully received LLM API response")
            except Exception as e:
                # Update agent_run status to error
                await client.table('agent_runs').update({
                    'status': 'error',
                    'error': str(e),
                    'completed_at': 'now()'
                }).eq('id', run_id).execute()

                logger.error(f"Failed to make LLM API call: {str(e)}", exc_info=True)
                raise

            # 7. Process LLM response using the ResponseProcessor
            if stream:
                logger.info("Processing streaming response")
                response_generator = self.response_processor.process_streaming_response(
                    llm_response=llm_response,
                    thread_id=thread_id,
                    config=processor_config
                )

                # Wrap the generator to update the agent_run when complete
                async def wrapped_generator():
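                    # Accumulate yielded chunks so the complete run can be persisted
                    # to agent_runs once the stream finishes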
                    responses = []
                    try:
                        async for chunk in response_generator:
                            responses.append(chunk)
                            yield chunk

                        # Update agent_run to completed when done
                        await client.table('agent_runs').update({
                            'status': 'completed',
                            'responses': json.dumps(responses),
                            'completed_at': 'now()'
                        }).eq('id', run_id).execute()
                        logger.debug(f"Updated agent run {run_id} to completed status")
                    except Exception as e:
                        # Update agent_run to error
                        await client.table('agent_runs').update({
                            'status': 'error',
                            'error': str(e),
                            'completed_at': 'now()'
                        }).eq('id', run_id).execute()
                        logger.error(f"Error in streaming response: {str(e)}", exc_info=True)
                        raise

                return wrapped_generator()

            else:
                logger.info("Processing non-streaming response")
                try:
                    response = await self.response_processor.process_non_streaming_response(
                        llm_response=llm_response,
                        thread_id=thread_id,
                        config=processor_config
                    )

                    # Update agent_run to completed
                    await client.table('agent_runs').update({
                        'status': 'completed',
                        'responses': json.dumps([response]),
                        'completed_at': 'now()'
                    }).eq('id', run_id).execute()
                    logger.debug(f"Updated agent run {run_id} to completed status")

                    return response
                except Exception as e:
                    # Update agent_run to error
                    await client.table('agent_runs').update({
                        'status': 'error',
                        'error': str(e),
                        'completed_at': 'now()'
                    }).eq('id', run_id).execute()
                    logger.error(f"Error in non-streaming response: {str(e)}", exc_info=True)
                    raise

        except Exception as e:
            logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
            return {
                "status": "error",
                "message": str(e)
            }
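

# ---------------------------------------------------------------------------
# Illustrative usage sketch (commented out; not part of the module).
# It assumes a configured Supabase connection, an existing thread row, and a
# hypothetical `FilesTool` Tool subclass -- adapt the names to your setup and
# check ProcessorConfig's actual fields before passing options.
#
#   manager = ThreadManager()
#   manager.add_tool(FilesTool)
#   result = await manager.run_thread(
#       thread_id="<existing-thread-uuid>",
#       system_prompt={"role": "system", "content": "You are a helpful assistant."},
#       stream=True,
#   )
#   if isinstance(result, dict) and result.get("status") == "error":
#       logger.error(result["message"])
#   else:
#       async for chunk in result:
#           ...  # forward each chunk to the caller/UI
# ---------------------------------------------------------------------------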