suna/backend/agentpress/context_manager.py

"""
Context Management for AgentPress Threads.
This module handles token counting and thread summarization to prevent
reaching the context window limitations of LLM models.
"""
import json
2025-07-06 00:07:35 +08:00
from typing import List, Dict, Any, Optional, Union
2025-04-17 07:16:53 +08:00
2025-07-06 00:07:35 +08:00
from litellm.utils import token_counter
2025-04-17 07:16:53 +08:00
from services.supabase import DBConnection
from utils.logger import logger
2025-07-06 12:40:44 +08:00
DEFAULT_TOKEN_THRESHOLD = 120000
2025-04-17 07:16:53 +08:00
class ContextManager:
    """Manages thread context: token counting, message compression, and truncation."""

    def __init__(self, token_threshold: int = DEFAULT_TOKEN_THRESHOLD):
        """Initialize the ContextManager.

        Args:
            token_threshold: Token count threshold that triggers compression.
        """
        self.db = DBConnection()
        self.token_threshold = token_threshold
    def is_tool_result_message(self, msg: Dict[str, Any]) -> bool:
        """Check whether a message is a tool result message."""
        if not isinstance(msg, dict) or not msg.get('content'):
            return False
        content = msg['content']
        if isinstance(content, str) and "ToolResult" in content:
            return True
        if isinstance(content, dict) and "tool_execution" in content:
            return True
        if isinstance(content, dict) and "interactive_elements" in content:
            return True
        if isinstance(content, str):
            try:
                parsed_content = json.loads(content)
                if isinstance(parsed_content, dict) and "tool_execution" in parsed_content:
                    return True
                if isinstance(parsed_content, dict) and "interactive_elements" in parsed_content:
                    return True
            except (json.JSONDecodeError, TypeError):
                pass
        return False
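    # Illustrative shapes (made-up payloads) and how the checks above classify them:
    #   {"content": "ToolResult(success=True, ...)"}    -> True  (marker substring)
    #   {"content": {"tool_execution": {...}}}          -> True  (dict payload)
    #   {"content": '{"interactive_elements": []}'}     -> True  (JSON string payload)
    #   {"role": "user", "content": "hello"}            -> False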
    def compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
        """Compress the message content, leaving a pointer to the full message."""
        if isinstance(msg_content, str):
            if len(msg_content) > max_length:
                return msg_content[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
            return msg_content
        elif isinstance(msg_content, dict):
            if len(json.dumps(msg_content)) > max_length:
                # Special handling for edit_file tool results to preserve JSON structure
                tool_execution = msg_content.get("tool_execution", {})
                if tool_execution.get("function_name") == "edit_file":
                    output = tool_execution.get("result", {}).get("output", {})
                    if isinstance(output, dict):
                        # Truncate file contents within the JSON
                        for key in ["original_content", "updated_content"]:
                            if isinstance(output.get(key), str) and len(output[key]) > max_length // 4:
                                output[key] = output[key][:max_length // 4] + "\n... (truncated)"
                # After potential truncation, check size again; if still too large,
                # fall back to plain string truncation
                if len(json.dumps(msg_content)) > max_length:
                    return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
            return msg_content
        else:
            return msg_content
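    # Worked example (with a made-up id): under the default max_length=3000, a
    # 10,000-char string becomes its first 3,000 chars followed by
    # '... (truncated)\n\nmessage_id "abc-123"\nUse expand-message tool to see contents'.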
    def safe_truncate(self, msg_content: Union[str, dict], max_length: int = 100000) -> Union[str, dict]:
        """Truncate the message content safely by removing the middle portion."""
        max_length = min(max_length, 100000)
        if isinstance(msg_content, str):
            if len(msg_content) > max_length:
                # Calculate how much to keep from the start and the end
                keep_length = max_length - 150  # Reserve space for the truncation notice
                start_length = keep_length // 2
                end_length = keep_length - start_length

                start_part = msg_content[:start_length]
                end_part = msg_content[-end_length:] if end_length > 0 else ""

                return start_part + "\n\n... (middle truncated) ...\n\n" + end_part + "\n\nThis message is too long; repeat relevant information in your response to remember it"
            return msg_content
        elif isinstance(msg_content, dict):
            json_str = json.dumps(msg_content)
            if len(json_str) > max_length:
                keep_length = max_length - 150  # Reserve space for the truncation notice
                start_length = keep_length // 2
                end_length = keep_length - start_length

                start_part = json_str[:start_length]
                end_part = json_str[-end_length:] if end_length > 0 else ""

                return start_part + "\n\n... (middle truncated) ...\n\n" + end_part + "\n\nThis message is too long; repeat relevant information in your response to remember it"
            return msg_content
        return msg_content
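    # Worked example: safe_truncate(s, max_length=1000) on a 5,000-char string keeps
    # keep_length = 1000 - 150 = 850 chars (the first 425 and the last 425), joined
    # by the "... (middle truncated) ..." notice.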
    def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
        """Compress tool result messages, except the most recent one."""
        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
        max_tokens_value = max_tokens or (100 * 1000)

        if uncompressed_total_token_count > max_tokens_value:
            _i = 0  # Number of ToolResult messages seen so far
            for msg in reversed(messages):  # Start from the end and work backwards
                if not isinstance(msg, dict):
                    continue  # Skip non-dict messages
                if self.is_tool_result_message(msg):  # Only compress ToolResult messages
                    _i += 1
                    msg_token_count = token_counter(model=llm_model, messages=[msg])  # Tokens in this message
                    if msg_token_count > token_threshold:  # Message is too long
                        if _i > 1:  # Not the most recent ToolResult message
                            message_id = msg.get('message_id')
                            if message_id:
                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
                            else:
                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                        else:
                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
        return messages
    def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
        """Compress user messages, except the most recent one."""
        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
        max_tokens_value = max_tokens or (100 * 1000)

        if uncompressed_total_token_count > max_tokens_value:
            _i = 0  # Number of user messages seen so far
            for msg in reversed(messages):  # Start from the end and work backwards
                if not isinstance(msg, dict):
                    continue  # Skip non-dict messages
                if msg.get('role') == 'user':  # Only compress user messages
                    _i += 1
                    msg_token_count = token_counter(model=llm_model, messages=[msg])  # Tokens in this message
                    if msg_token_count > token_threshold:  # Message is too long
                        if _i > 1:  # Not the most recent user message
                            message_id = msg.get('message_id')
                            if message_id:
                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
                            else:
                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                        else:
                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
        return messages
    def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
        """Compress assistant messages, except the most recent one."""
        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
        max_tokens_value = max_tokens or (100 * 1000)

        if uncompressed_total_token_count > max_tokens_value:
            _i = 0  # Number of assistant messages seen so far
            for msg in reversed(messages):  # Start from the end and work backwards
                if not isinstance(msg, dict):
                    continue  # Skip non-dict messages
                if msg.get('role') == 'assistant':  # Only compress assistant messages
                    _i += 1
                    msg_token_count = token_counter(model=llm_model, messages=[msg])  # Tokens in this message
                    if msg_token_count > token_threshold:  # Message is too long
                        if _i > 1:  # Not the most recent assistant message
                            message_id = msg.get('message_id')
                            if message_id:
                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
                            else:
                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                        else:
                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
        return messages
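    # The three compress_* methods above share one pattern: iterating in reverse makes
    # _i == 1 the newest matching message, which is middle-truncated with a large
    # budget (2x the total token limit, capped at 100,000 chars by safe_truncate);
    # older matches are pointer-compressed to about token_threshold * 3 characters,
    # a rough 3-chars-per-token heuristic (our reading, not documented by the code).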
    def remove_meta_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Remove meta fields (tool call arguments) from the messages."""
        result: List[Dict[str, Any]] = []
        for msg in messages:
            if not isinstance(msg, dict):
                result.append(msg)  # Pass non-dict messages through unchanged
                continue
            msg_content = msg.get('content')

            # Try to parse msg_content as JSON if it's a string
            if isinstance(msg_content, str):
                try:
                    msg_content = json.loads(msg_content)
                except json.JSONDecodeError:
                    pass

            if isinstance(msg_content, dict):
                # Work on a copy to avoid modifying the original
                msg_content_copy = msg_content.copy()
                if "tool_execution" in msg_content_copy:
                    tool_execution = msg_content_copy["tool_execution"].copy()
                    if "arguments" in tool_execution:
                        del tool_execution["arguments"]
                    msg_content_copy["tool_execution"] = tool_execution
                # Create a new message dict with the modified content
                new_msg = msg.copy()
                new_msg["content"] = json.dumps(msg_content_copy)
                result.append(new_msg)
            else:
                result.append(msg)
        return result
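    # Example: a message whose content is
    #   {"tool_execution": {"function_name": "edit_file", "arguments": {...}, "result": {...}}}
    # is re-serialized without the "arguments" key; string content that is not
    # valid JSON passes through unchanged.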
    def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
        """Compress the messages.

        Args:
            messages: List of messages to compress
            llm_model: Model name for token counting
            max_tokens: Maximum allowed tokens
            token_threshold: Token threshold for individual message compression (must be a power of 2)
            max_iterations: Maximum number of compression iterations
        """
        # Set model-specific token limits (this overrides the max_tokens argument)
        if 'sonnet' in llm_model.lower():
            max_tokens = 200 * 1000 - 64000 - 28000
        elif 'gpt' in llm_model.lower():
            max_tokens = 128 * 1000 - 28000
        elif 'gemini' in llm_model.lower():
            max_tokens = 1000 * 1000 - 300000
        elif 'deepseek' in llm_model.lower():
            max_tokens = 128 * 1000 - 28000
        else:
            max_tokens = 41 * 1000 - 10000

        result = messages
        result = self.remove_meta_messages(result)

        uncompressed_total_token_count = token_counter(model=llm_model, messages=result)

        result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
        result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
        result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)

        compressed_token_count = token_counter(model=llm_model, messages=result)

        logger.debug(f"compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}")  # Log the compression ratio for later debugging

        if max_iterations <= 0:
            logger.warning("compress_messages: Max iterations reached, omitting messages")
            result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
            return result

        if compressed_token_count > max_tokens:
            logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
            result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)

        return self.middle_out_messages(result)
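    # Budget arithmetic for the limits above: 'sonnet' models, for example, get
    # 200,000 - 64,000 - 28,000 = 108,000 tokens (the deductions presumably reserve
    # room for output and overhead). Each recursive pass halves token_threshold
    # (4096 -> 2048 -> ...), and once max_iterations passes are exhausted the method
    # falls back to dropping whole messages via compress_messages_by_omitting_messages.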
    def compress_messages_by_omitting_messages(
        self,
        messages: List[Dict[str, Any]],
        llm_model: str,
        max_tokens: Optional[int] = 41000,
        removal_batch_size: int = 10,
        min_messages_to_keep: int = 10
    ) -> List[Dict[str, Any]]:
        """Compress the messages by omitting messages from the middle.

        Args:
            messages: List of messages to compress
            llm_model: Model name for token counting
            max_tokens: Maximum allowed tokens
            removal_batch_size: Number of messages to remove per iteration
            min_messages_to_keep: Minimum number of messages to preserve
        """
        if not messages:
            return messages

        result = messages
        result = self.remove_meta_messages(result)

        # Early exit if no compression is needed
        initial_token_count = token_counter(model=llm_model, messages=result)
        max_allowed_tokens = max_tokens or (100 * 1000)
        if initial_token_count <= max_allowed_tokens:
            return result

        # Separate the system message (assumed to be first) from conversation messages
        system_message = result[0] if result and isinstance(result[0], dict) and result[0].get('role') == 'system' else None
        conversation_messages = result[1:] if system_message else result

        safety_limit = 500
        current_token_count = initial_token_count

        while current_token_count > max_allowed_tokens and safety_limit > 0:
            safety_limit -= 1

            if len(conversation_messages) <= min_messages_to_keep:
                logger.warning(f"Cannot compress further: only {len(conversation_messages)} messages remain (min: {min_messages_to_keep})")
                break

            # Choose a removal strategy based on the current message count
            if len(conversation_messages) > (removal_batch_size * 2):
                # Remove from the middle, keeping recent and early context
                middle_start = len(conversation_messages) // 2 - (removal_batch_size // 2)
                middle_end = middle_start + removal_batch_size
                conversation_messages = conversation_messages[:middle_start] + conversation_messages[middle_end:]
            else:
                # Remove from earlier messages, preserving recent context
                messages_to_remove = min(removal_batch_size, len(conversation_messages) // 2)
                if messages_to_remove > 0:
                    conversation_messages = conversation_messages[messages_to_remove:]
                else:
                    # Can't remove any more messages
                    break

            # Recalculate the token count
            messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
            current_token_count = token_counter(model=llm_model, messages=messages_to_count)

        # Prepare the final result
        final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
        final_token_count = token_counter(model=llm_model, messages=final_messages)

        logger.debug(f"compress_messages_by_omitting_messages: {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")

        return final_messages
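    # Worked example: with 40 conversation messages and removal_batch_size=10,
    # 40 > 2 * 10, so the middle strategy removes indices 15..24
    # (middle_start = 40 // 2 - 10 // 2 = 15); once 20 or fewer remain, batches
    # are trimmed from the front instead, preserving recent context.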
    def middle_out_messages(self, messages: List[Dict[str, Any]], max_messages: int = 320) -> List[Dict[str, Any]]:
        """Remove messages from the middle of the list, keeping at most max_messages total."""
        if len(messages) <= max_messages:
            return messages

        # Keep half from the beginning and half from the end
        keep_start = max_messages // 2
        keep_end = max_messages - keep_start
        return messages[:keep_start] + messages[-keep_end:]
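

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that
    # constructing DBConnection() is lazy and does not require a live database,
    # and that litellm can count tokens for the model name used below.
    cm = ContextManager()
    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant.", "message_id": "sys-1"},
        {"role": "user", "content": "Summarize this: " + "lorem ipsum " * 50000, "message_id": "usr-1"},
        {"role": "assistant", "content": "Working on it.", "message_id": "asst-1"},
    ]
    compressed = cm.compress_messages(demo_messages, llm_model="gpt-4o")
    print(token_counter(model="gpt-4o", messages=compressed))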