suna/backend/agentpress/context_manager.py

315 lines
16 KiB
Python

"""
Context Management for AgentPress Threads.
This module handles token counting and thread summarization to prevent
reaching the context window limitations of LLM models.
"""
import json
from typing import List, Dict, Any, Optional, Union
from litellm.utils import token_counter
from services.supabase import DBConnection
from utils.logger import logger
DEFAULT_TOKEN_THRESHOLD = 120000
class ContextManager:
"""Manages thread context including token counting and summarization."""
def __init__(self, token_threshold: int = DEFAULT_TOKEN_THRESHOLD):
"""Initialize the ContextManager.
Args:
token_threshold: Token count threshold to trigger summarization
"""
self.db = DBConnection()
self.token_threshold = token_threshold
def is_tool_result_message(self, msg: Dict[str, Any]) -> bool:
"""Check if a message is a tool result message."""
if not ("content" in msg and msg['content']):
return False
content = msg['content']
if isinstance(content, str) and "ToolResult" in content:
return True
if isinstance(content, dict) and "tool_execution" in content:
return True
if isinstance(content, dict) and "interactive_elements" in content:
return True
if isinstance(content, str):
try:
parsed_content = json.loads(content)
if isinstance(parsed_content, dict) and "tool_execution" in parsed_content:
return True
if isinstance(parsed_content, dict) and "interactive_elements" in content:
return True
except (json.JSONDecodeError, TypeError):
pass
return False
def compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
"""Compress the message content."""
if isinstance(msg_content, str):
if len(msg_content) > max_length:
return msg_content[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
else:
return msg_content
elif isinstance(msg_content, dict):
if len(json.dumps(msg_content)) > max_length:
return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
else:
return msg_content
def safe_truncate(self, msg_content: Union[str, dict], max_length: int = 100000) -> Union[str, dict]:
"""Truncate the message content safely by removing the middle portion."""
max_length = min(max_length, 100000)
if isinstance(msg_content, str):
if len(msg_content) > max_length:
# Calculate how much to keep from start and end
keep_length = max_length - 150 # Reserve space for truncation message
start_length = keep_length // 2
end_length = keep_length - start_length
start_part = msg_content[:start_length]
end_part = msg_content[-end_length:] if end_length > 0 else ""
return start_part + f"\n\n... (middle truncated) ...\n\n" + end_part + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
else:
return msg_content
elif isinstance(msg_content, dict):
json_str = json.dumps(msg_content)
if len(json_str) > max_length:
# Calculate how much to keep from start and end
keep_length = max_length - 150 # Reserve space for truncation message
start_length = keep_length // 2
end_length = keep_length - start_length
start_part = json_str[:start_length]
end_part = json_str[-end_length:] if end_length > 0 else ""
return start_part + f"\n\n... (middle truncated) ...\n\n" + end_part + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
else:
return msg_content
def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
_i = 0 # Count the number of ToolResult messages
for msg in reversed(messages): # Start from the end and work backwards
if self.is_tool_result_message(msg): # Only compress ToolResult messages
_i += 1 # Count the number of ToolResult messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > token_threshold: # If the message is too long
if _i > 1: # If this is not the most recent ToolResult message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
else:
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
else:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
_i = 0 # Count the number of User messages
for msg in reversed(messages): # Start from the end and work backwards
if msg.get('role') == 'user': # Only compress User messages
_i += 1 # Count the number of User messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > token_threshold: # If the message is too long
if _i > 1: # If this is not the most recent User message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
else:
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
else:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
_i = 0 # Count the number of Assistant messages
for msg in reversed(messages): # Start from the end and work backwards
if msg.get('role') == 'assistant': # Only compress Assistant messages
_i += 1 # Count the number of Assistant messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > token_threshold: # If the message is too long
if _i > 1: # If this is not the most recent Assistant message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
else:
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
else:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def remove_meta_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Remove meta messages from the messages."""
result: List[Dict[str, Any]] = []
for msg in messages:
msg_content = msg.get('content')
# Try to parse msg_content as JSON if it's a string
if isinstance(msg_content, str):
try:
msg_content = json.loads(msg_content)
except json.JSONDecodeError:
pass
if isinstance(msg_content, dict):
# Create a copy to avoid modifying the original
msg_content_copy = msg_content.copy()
if "tool_execution" in msg_content_copy:
tool_execution = msg_content_copy["tool_execution"].copy()
if "arguments" in tool_execution:
del tool_execution["arguments"]
msg_content_copy["tool_execution"] = tool_execution
# Create a new message dict with the modified content
new_msg = msg.copy()
new_msg["content"] = json.dumps(msg_content_copy)
result.append(new_msg)
else:
result.append(msg)
return result
def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
"""Compress the messages.
Args:
messages: List of messages to compress
llm_model: Model name for token counting
max_tokens: Maximum allowed tokens
token_threshold: Token threshold for individual message compression (must be a power of 2)
max_iterations: Maximum number of compression iterations
"""
# Set model-specific token limits
if 'sonnet' in llm_model.lower():
max_tokens = 200 * 1000 - 64000 - 28000
elif 'gpt' in llm_model.lower():
max_tokens = 128 * 1000 - 28000
elif 'gemini' in llm_model.lower():
max_tokens = 1000 * 1000 - 300000
elif 'deepseek' in llm_model.lower():
max_tokens = 128 * 1000 - 28000
else:
max_tokens = 41 * 1000 - 10000
result = messages
result = self.remove_meta_messages(result)
uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
compressed_token_count = token_counter(model=llm_model, messages=result)
logger.info(f"compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}") # Log the token compression for debugging later
if max_iterations <= 0:
logger.warning(f"compress_messages: Max iterations reached, omitting messages")
result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
return result
if compressed_token_count > max_tokens:
logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)
return self.middle_out_messages(result)
def compress_messages_by_omitting_messages(
self,
messages: List[Dict[str, Any]],
llm_model: str,
max_tokens: Optional[int] = 41000,
removal_batch_size: int = 10,
min_messages_to_keep: int = 10
) -> List[Dict[str, Any]]:
"""Compress the messages by omitting messages from the middle.
Args:
messages: List of messages to compress
llm_model: Model name for token counting
max_tokens: Maximum allowed tokens
removal_batch_size: Number of messages to remove per iteration
min_messages_to_keep: Minimum number of messages to preserve
"""
if not messages:
return messages
result = messages
result = self.remove_meta_messages(result)
# Early exit if no compression needed
initial_token_count = token_counter(model=llm_model, messages=result)
max_allowed_tokens = max_tokens or (100 * 1000)
if initial_token_count <= max_allowed_tokens:
return result
# Separate system message (assumed to be first) from conversation messages
system_message = messages[0] if messages and messages[0].get('role') == 'system' else None
conversation_messages = result[1:] if system_message else result
safety_limit = 500
current_token_count = initial_token_count
while current_token_count > max_allowed_tokens and safety_limit > 0:
safety_limit -= 1
if len(conversation_messages) <= min_messages_to_keep:
logger.warning(f"Cannot compress further: only {len(conversation_messages)} messages remain (min: {min_messages_to_keep})")
break
# Calculate removal strategy based on current message count
if len(conversation_messages) > (removal_batch_size * 2):
# Remove from middle, keeping recent and early context
middle_start = len(conversation_messages) // 2 - (removal_batch_size // 2)
middle_end = middle_start + removal_batch_size
conversation_messages = conversation_messages[:middle_start] + conversation_messages[middle_end:]
else:
# Remove from earlier messages, preserving recent context
messages_to_remove = min(removal_batch_size, len(conversation_messages) // 2)
if messages_to_remove > 0:
conversation_messages = conversation_messages[messages_to_remove:]
else:
# Can't remove any more messages
break
# Recalculate token count
messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
current_token_count = token_counter(model=llm_model, messages=messages_to_count)
# Prepare final result
final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
final_token_count = token_counter(model=llm_model, messages=final_messages)
logger.info(f"compress_messages_by_omitting_messages: {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")
return final_messages
def middle_out_messages(self, messages: List[Dict[str, Any]], max_messages: int = 320) -> List[Dict[str, Any]]:
"""Remove messages from the middle of the list, keeping max_messages total."""
if len(messages) <= max_messages:
return messages
# Keep half from the beginning and half from the end
keep_start = max_messages // 2
keep_end = max_messages - keep_start
return messages[:keep_start] + messages[-keep_end:]