suna/backend/core/agentpress/thread_manager.py

"""
Simplified conversation thread management system for AgentPress.
"""
import json
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal, cast
from core.services.llm import make_llm_api_call, LLMError
from core.agentpress.prompt_caching import apply_anthropic_caching_strategy, validate_cache_blocks
from core.agentpress.tool import Tool
from core.agentpress.tool_registry import ToolRegistry
from core.agentpress.context_manager import ContextManager
from core.agentpress.response_processor import ResponseProcessor, ProcessorConfig
from core.agentpress.error_processor import ErrorProcessor
from core.services.supabase import DBConnection
from core.utils.logger import logger
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from core.services.langfuse import langfuse
from datetime import datetime, timezone
from core.billing.billing_integration import billing_integration
from litellm.utils import token_counter
ToolChoice = Literal["auto", "required", "none"]
class ThreadManager:
"""Manages conversation threads with LLM models and tool execution."""
def __init__(self, trace: Optional[StatefulTraceClient] = None, agent_config: Optional[dict] = None):
self.db = DBConnection()
self.tool_registry = ToolRegistry()
self.trace = trace
if not self.trace:
self.trace = langfuse.trace(name="anonymous:thread_manager")
self.agent_config = agent_config
self.response_processor = ResponseProcessor(
tool_registry=self.tool_registry,
add_message_callback=self.add_message,
trace=self.trace,
agent_config=self.agent_config
)
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
"""Add a tool to the ThreadManager."""
self.tool_registry.register_tool(tool_class, function_names, **kwargs)
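# Illustrative sketch of add_tool usage (assumed caller code, not part of this module):
# `FilesTool` and the function names are placeholders for whatever Tool subclass the
# caller registers; extra kwargs are presumably forwarded to the tool's constructor.
#
#     manager = ThreadManager()
#     manager.add_tool(FilesTool, function_names=["create_file", "read_file"])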
async def create_thread(
self,
account_id: Optional[str] = None,
project_id: Optional[str] = None,
is_public: bool = False,
metadata: Optional[Dict[str, Any]] = None
) -> str:
"""Create a new thread in the database."""
# logger.debug(f"Creating new thread (account_id: {account_id}, project_id: {project_id})")
client = await self.db.client
thread_data = {'is_public': is_public, 'metadata': metadata or {}}
if account_id:
thread_data['account_id'] = account_id
if project_id:
thread_data['project_id'] = project_id
try:
result = await client.table('threads').insert(thread_data).execute()
if result.data and len(result.data) > 0 and 'thread_id' in result.data[0]:
thread_id = result.data[0]['thread_id']
logger.info(f"Successfully created thread: {thread_id}")
return thread_id
else:
raise Exception("Failed to create thread: no thread_id returned")
except Exception as e:
logger.error(f"Failed to create thread: {str(e)}", exc_info=True)
raise Exception(f"Thread creation failed: {str(e)}")
async def add_message(
self,
thread_id: str,
type: str,
content: Union[Dict[str, Any], List[Any], str],
is_llm_message: bool = False,
metadata: Optional[Dict[str, Any]] = None,
agent_id: Optional[str] = None,
agent_version_id: Optional[str] = None
):
"""Add a message to the thread in the database."""
# logger.debug(f"Adding message of type '{type}' to thread {thread_id}")
client = await self.db.client
data_to_insert = {
'thread_id': thread_id,
'type': type,
'content': content,
'is_llm_message': is_llm_message,
'metadata': metadata or {},
}
if agent_id:
data_to_insert['agent_id'] = agent_id
if agent_version_id:
data_to_insert['agent_version_id'] = agent_version_id
try:
result = await client.table('messages').insert(data_to_insert).execute()
# logger.debug(f"Successfully added message to thread {thread_id}")
if result.data and len(result.data) > 0 and 'message_id' in result.data[0]:
saved_message = result.data[0]
# Handle billing for assistant response end messages
if type == "assistant_response_end" and isinstance(content, dict):
await self._handle_billing(thread_id, content, saved_message)
return saved_message
else:
logger.error(f"Insert operation failed for thread {thread_id}")
return None
except Exception as e:
logger.error(f"Failed to add message to thread {thread_id}: {str(e)}", exc_info=True)
raise
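# Illustrative sketch of add_message usage (assumed caller code; the role/content layout
# is an assumption about how LLM-visible messages are stored, mirroring what
# get_llm_messages returns):
#
#     await manager.add_message(
#         thread_id=thread_id,
#         type="user",
#         content={"role": "user", "content": "Summarize the latest report"},
#         is_llm_message=True,
#     )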
async def _handle_billing(self, thread_id: str, content: dict, saved_message: dict):
"""Handle billing for LLM usage."""
try:
usage = content.get("usage", {})
# DEBUG: Log the complete usage object to see what data we have
# logger.info(f"🔍 THREAD MANAGER USAGE: {usage}")
# logger.info(f"🔍 THREAD MANAGER CONTENT: {content}")
prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
completion_tokens = int(usage.get("completion_tokens", 0) or 0)
# Try cache_read_input_tokens first (Anthropic standard), then fall back to prompt_tokens_details.cached_tokens
cache_read_tokens = int(usage.get("cache_read_input_tokens", 0) or 0)
if cache_read_tokens == 0:
cache_read_tokens = int(usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) or 0)
cache_creation_tokens = int(usage.get("cache_creation_input_tokens", 0) or 0)
model = content.get("model")
# DEBUG: Log what we detected
logger.info(f"🔍 CACHE DETECTION: cache_read={cache_read_tokens}, cache_creation={cache_creation_tokens}, prompt={prompt_tokens}")
client = await self.db.client
thread_row = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
user_id = thread_row.data[0]['account_id'] if thread_row.data and len(thread_row.data) > 0 else None
if user_id and (prompt_tokens > 0 or completion_tokens > 0):
if cache_read_tokens > 0:
cache_hit_percentage = (cache_read_tokens / prompt_tokens * 100) if prompt_tokens > 0 else 0
logger.info(f"🎯 CACHE HIT: {cache_read_tokens}/{prompt_tokens} tokens ({cache_hit_percentage:.1f}%)")
elif cache_creation_tokens > 0:
logger.info(f"💾 CACHE WRITE: {cache_creation_tokens} tokens stored for future use")
else:
logger.debug(f"❌ NO CACHE: All {prompt_tokens} tokens processed fresh")
deduct_result = await billing_integration.deduct_usage(
account_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
model=model or "unknown",
message_id=saved_message['message_id'],
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens
)
if deduct_result.get('success'):
logger.info(f"Successfully deducted ${deduct_result.get('cost', 0):.6f}")
else:
logger.error(f"Failed to deduct credits: {deduct_result}")
except Exception as e:
logger.error(f"Error handling billing: {str(e)}", exc_info=True)
async def get_llm_messages(self, thread_id: str) -> List[Dict[str, Any]]:
"""Get all messages for a thread."""
logger.debug(f"Getting messages for thread {thread_id}")
client = await self.db.client
try:
all_messages = []
batch_size = 1000
offset = 0
while True:
result = await client.table('messages').select('message_id, type, content').eq('thread_id', thread_id).eq('is_llm_message', True).order('created_at').range(offset, offset + batch_size - 1).execute()
if not result.data:
break
all_messages.extend(result.data)
if len(result.data) < batch_size:
break
offset += batch_size
if not all_messages:
return []
messages = []
for item in all_messages:
if isinstance(item['content'], str):
try:
parsed_item = json.loads(item['content'])
parsed_item['message_id'] = item['message_id']
messages.append(parsed_item)
except json.JSONDecodeError:
logger.error(f"Failed to parse message: {item['content']}")
else:
content = item['content']
content['message_id'] = item['message_id']
messages.append(content)
return messages
except Exception as e:
logger.error(f"Failed to get messages for thread {thread_id}: {str(e)}", exc_info=True)
return []
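# A hedged sketch of get_llm_messages' return shape: each element is the stored message
# content (typically a role/content dict) with the database message_id injected, e.g.
#
#     [
#         {"role": "user", "content": "Hello", "message_id": "uuid-1"},
#         {"role": "assistant", "content": "Hi there", "message_id": "uuid-2"},
#     ]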
async def run_thread(
self,
thread_id: str,
system_prompt: Dict[str, Any],
stream: bool = True,
temporary_message: Optional[Dict[str, Any]] = None,
llm_model: str = "gpt-5",
llm_temperature: float = 0,
llm_max_tokens: Optional[int] = None,
processor_config: Optional[ProcessorConfig] = None,
tool_choice: ToolChoice = "auto",
native_max_auto_continues: int = 25,
max_xml_tool_calls: int = 0,
generation: Optional[StatefulGenerationClient] = None,
) -> Union[Dict[str, Any], AsyncGenerator]:
"""Run a conversation thread with LLM integration and tool execution."""
logger.debug(f"🚀 Starting thread execution for {thread_id} with model {llm_model}")
# Ensure we have a valid ProcessorConfig object
if processor_config is None:
config = ProcessorConfig()
elif isinstance(processor_config, ProcessorConfig):
config = processor_config
else:
logger.error(f"Invalid processor_config type: {type(processor_config)}, creating default")
config = ProcessorConfig()
if max_xml_tool_calls > 0 and not config.max_xml_tool_calls:
config.max_xml_tool_calls = max_xml_tool_calls
auto_continue_state = {
'count': 0,
'active': True,
'continuous_state': {'accumulated_content': '', 'thread_run_id': None}
}
# Single execution if auto-continue is disabled
if native_max_auto_continues == 0:
result = await self._execute_run(
thread_id, system_prompt, llm_model, llm_temperature, llm_max_tokens,
tool_choice, config, stream,
generation, auto_continue_state, temporary_message
)
# If result is an error dict, convert it to a generator that yields the error
if isinstance(result, dict) and result.get("status") == "error":
return self._create_single_error_generator(result)
return result
# Auto-continue execution
return self._auto_continue_generator(
thread_id, system_prompt, llm_model, llm_temperature, llm_max_tokens,
tool_choice, config, stream,
generation, auto_continue_state, temporary_message,
native_max_auto_continues
)
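# Illustrative sketch of consuming run_thread (assumed caller code): the result is either
# an error dict or an async generator of chunks, so callers typically branch on the type.
# `handle_error` and `render` are hypothetical helpers; the system prompt is a placeholder.
#
#     result = await manager.run_thread(
#         thread_id=thread_id,
#         system_prompt={"role": "system", "content": "You are a helpful assistant."},
#         llm_model="gpt-5",
#         stream=True,
#     )
#     if isinstance(result, dict) and result.get("status") == "error":
#         handle_error(result)
#     else:
#         async for chunk in result:
#             render(chunk)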
async def _execute_run(
self, thread_id: str, system_prompt: Dict[str, Any], llm_model: str,
llm_temperature: float, llm_max_tokens: Optional[int], tool_choice: ToolChoice,
config: ProcessorConfig, stream: bool, generation: Optional[StatefulGenerationClient],
auto_continue_state: Dict[str, Any], temporary_message: Optional[Dict[str, Any]] = None
) -> Union[Dict[str, Any], AsyncGenerator]:
"""Execute a single LLM run."""
# CRITICAL: Ensure config is always a ProcessorConfig object
if not isinstance(config, ProcessorConfig):
logger.error(f"ERROR: config is {type(config)}, expected ProcessorConfig. Value: {config}")
config = ProcessorConfig() # Create new instance as fallback
try:
# Get and prepare messages
messages = await self.get_llm_messages(thread_id)
# Handle auto-continue context
if auto_continue_state['count'] > 0 and auto_continue_state['continuous_state'].get('accumulated_content'):
partial_content = auto_continue_state['continuous_state']['accumulated_content']
messages.append({"role": "assistant", "content": partial_content})
# ===== CENTRAL CONFIGURATION =====
ENABLE_CONTEXT_MANAGER = True # Set to False to disable context compression
ENABLE_PROMPT_CACHING = True # Set to False to disable prompt caching
# ==================================
# Apply context compression
if ENABLE_CONTEXT_MANAGER:
logger.debug(f"Context manager enabled, compressing {len(messages)} messages")
context_manager = ContextManager()
compressed_messages = context_manager.compress_messages(
messages, llm_model, max_tokens=llm_max_tokens,
actual_total_tokens=None, # Will be calculated inside
system_prompt=system_prompt # system prompt has no cache blocks yet; caching is applied after compression
)
logger.debug(f"Context compression completed: {len(messages)} -> {len(compressed_messages)} messages")
messages = compressed_messages
else:
logger.debug("Context manager disabled, using raw messages")
# Apply caching
if ENABLE_PROMPT_CACHING:
prepared_messages = apply_anthropic_caching_strategy(system_prompt, messages, llm_model)
prepared_messages = validate_cache_blocks(prepared_messages, llm_model)
else:
prepared_messages = [system_prompt] + messages
# Get tool schemas if needed
openapi_tool_schemas = self.tool_registry.get_openapi_schemas() if config.native_tool_calling else None
# Update generation tracking
if generation:
try:
generation.update(
input=prepared_messages,
start_time=datetime.now(timezone.utc),
model=llm_model,
model_parameters={
"max_tokens": llm_max_tokens,
"temperature": llm_temperature,
"tool_choice": tool_choice,
"tools": openapi_tool_schemas,
}
)
except Exception as e:
logger.warning(f"Failed to update Langfuse generation: {e}")
# Log final prepared messages token count
final_prepared_tokens = token_counter(model=llm_model, messages=prepared_messages)
logger.info(f"📤 Final prepared messages being sent to LLM: {final_prepared_tokens} tokens")
# Make LLM call
try:
llm_response = await make_llm_api_call(
prepared_messages, llm_model,
temperature=llm_temperature,
max_tokens=llm_max_tokens,
tools=openapi_tool_schemas,
tool_choice=tool_choice if config.native_tool_calling else "none",
stream=stream
)
except LLMError as e:
return {"type": "status", "status": "error", "message": str(e)}
# Check for error response
if isinstance(llm_response, dict) and llm_response.get("status") == "error":
return llm_response
# Process the response (config was already validated as a ProcessorConfig above)
if stream and hasattr(llm_response, '__aiter__'):
return self.response_processor.process_streaming_response(
cast(AsyncGenerator, llm_response), thread_id, prepared_messages,
llm_model, config, True,
auto_continue_state['count'], auto_continue_state['continuous_state'],
generation
)
else:
return self.response_processor.process_non_streaming_response(
llm_response, thread_id, prepared_messages, llm_model, config, generation
)
except Exception as e:
processed_error = ErrorProcessor.process_system_error(e, context={"thread_id": thread_id})
ErrorProcessor.log_error(processed_error)
return processed_error.to_stream_dict()
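# A hedged illustration of what the prompt-caching step above presumably produces for
# Anthropic models: apply_anthropic_caching_strategy is expected to attach Anthropic-style
# cache_control blocks to stable content, roughly like the system message below. The exact
# placement is decided inside prompt_caching.py; this is only an assumption for orientation.
#
#     {
#         "role": "system",
#         "content": [
#             {"type": "text", "text": "...", "cache_control": {"type": "ephemeral"}},
#         ],
#     }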
async def _auto_continue_generator(
self, thread_id: str, system_prompt: Dict[str, Any], llm_model: str,
llm_temperature: float, llm_max_tokens: Optional[int], tool_choice: ToolChoice,
config: ProcessorConfig, stream: bool, generation: Optional[StatefulGenerationClient],
auto_continue_state: Dict[str, Any], temporary_message: Optional[Dict[str, Any]],
native_max_auto_continues: int
) -> AsyncGenerator:
"""Generator that handles auto-continue logic."""
logger.debug(f"Starting auto-continue generator, max: {native_max_auto_continues}")
# logger.debug(f"Config type in auto-continue generator: {type(config)}")
# Ensure config is valid ProcessorConfig
if not isinstance(config, ProcessorConfig):
logger.error(f"Invalid config type in auto-continue: {type(config)}, creating new one")
config = ProcessorConfig()
while auto_continue_state['active'] and auto_continue_state['count'] < native_max_auto_continues:
auto_continue_state['active'] = False # Reset for this iteration
try:
response_gen = await self._execute_run(
thread_id, system_prompt, llm_model, llm_temperature, llm_max_tokens,
tool_choice, config, stream,
generation, auto_continue_state,
temporary_message if auto_continue_state['count'] == 0 else None
)
# Handle error responses
if isinstance(response_gen, dict) and response_gen.get("status") == "error":
yield response_gen
break
# Process streaming response
if hasattr(response_gen, '__aiter__'):
async for chunk in cast(AsyncGenerator, response_gen):
# Check for auto-continue triggers
should_continue = self._check_auto_continue_trigger(
chunk, auto_continue_state, native_max_auto_continues
)
# Suppress the chunks that triggered auto-continue (tool_calls finish / length status)
if should_continue:
if chunk.get('type') == 'finish' and chunk.get('finish_reason') == 'tool_calls':
continue
elif chunk.get('type') == 'status':
try:
content = json.loads(chunk.get('content', '{}'))
if content.get('finish_reason') == 'length':
continue
except (json.JSONDecodeError, TypeError):
pass
yield chunk
else:
yield response_gen
if not auto_continue_state['active']:
break
except Exception as e:
if "AnthropicException - Overloaded" in str(e):
logger.error(f"Anthropic overloaded, falling back to OpenRouter")
llm_model = f"openrouter/{llm_model.replace('-20250514', '')}"
auto_continue_state['active'] = True
continue
else:
processed_error = ErrorProcessor.process_system_error(e, context={"thread_id": thread_id})
ErrorProcessor.log_error(processed_error)
yield processed_error.to_stream_dict()
return
# Handle max iterations reached
if auto_continue_state['active'] and auto_continue_state['count'] >= native_max_auto_continues:
logger.warning(f"Reached maximum auto-continue limit ({native_max_auto_continues})")
yield {
"type": "content",
"content": f"\n[Agent reached maximum auto-continue limit of {native_max_auto_continues}]"
}
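# Note on the overload fallback above: the date suffix is stripped before re-routing, so a
# model id like "claude-sonnet-4-20250514" (assumed example) is retried as
# "openrouter/claude-sonnet-4" on the next loop iteration.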
def _check_auto_continue_trigger(
self, chunk: Dict[str, Any], auto_continue_state: Dict[str, Any],
native_max_auto_continues: int
) -> bool:
"""Check if a response chunk should trigger auto-continue."""
if chunk.get('type') == 'finish':
if chunk.get('finish_reason') == 'tool_calls':
if native_max_auto_continues > 0:
logger.debug(f"Auto-continuing for tool_calls ({auto_continue_state['count'] + 1}/{native_max_auto_continues})")
auto_continue_state['active'] = True
auto_continue_state['count'] += 1
return True
elif chunk.get('finish_reason') == 'xml_tool_limit_reached':
logger.debug("Stopping auto-continue due to XML tool limit")
auto_continue_state['active'] = False
elif chunk.get('type') == 'status':
try:
content = json.loads(chunk.get('content', '{}'))
if content.get('finish_reason') == 'length':
logger.debug(f"Auto-continuing for length limit ({auto_continue_state['count'] + 1}/{native_max_auto_continues})")
auto_continue_state['active'] = True
auto_continue_state['count'] += 1
return True
except (json.JSONDecodeError, TypeError):
pass
return False
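# Hedged examples of chunks that trigger auto-continue, based on the checks above:
#
#     {"type": "finish", "finish_reason": "tool_calls"}              # native tool-call round
#     {"type": "status", "content": '{"finish_reason": "length"}'}   # output hit the length limit
#
# A chunk like {"type": "finish", "finish_reason": "xml_tool_limit_reached"} disables
# auto-continue instead.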
async def _create_single_error_generator(self, error_dict: Dict[str, Any]):
"""Create an async generator that yields a single error message."""
yield error_dict