"""
|
|
Context Management for AgentPress Threads.
|
|
|
|
This module handles token counting and thread summarization to prevent
|
|
reaching the context window limitations of LLM models.
|
|
"""
|
|
|
|
import json
|
|
from typing import List, Dict, Any, Optional, Union
|
|
|
|
from litellm.utils import token_counter
|
|
from services.supabase import DBConnection
|
|
from utils.logger import logger
|
|
|
|
DEFAULT_TOKEN_THRESHOLD = 120000
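# Threads above this token count become candidates for summarization; the value
# leaves headroom below common context windows (see the per-model limits in
# compress_messages).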


class ContextManager:
    """Manages thread context including token counting and summarization."""

    def __init__(self, token_threshold: int = DEFAULT_TOKEN_THRESHOLD):
        """Initialize the ContextManager.

        Args:
            token_threshold: Token count threshold to trigger summarization
        """
        self.db = DBConnection()
        self.token_threshold = token_threshold

    def is_tool_result_message(self, msg: Dict[str, Any]) -> bool:
        """Check if a message is a tool result message."""
        if not isinstance(msg, dict) or not msg.get('content'):
            return False
        content = msg['content']
        if isinstance(content, str) and "ToolResult" in content:
            return True
        if isinstance(content, dict) and "tool_execution" in content:
            return True
        if isinstance(content, dict) and "interactive_elements" in content:
            return True
        if isinstance(content, str):
            try:
                parsed_content = json.loads(content)
                if isinstance(parsed_content, dict) and "tool_execution" in parsed_content:
                    return True
                if isinstance(parsed_content, dict) and "interactive_elements" in parsed_content:
                    return True
            except (json.JSONDecodeError, TypeError):
                pass
        return False
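
    # Illustrative message shapes (hypothetical payloads, not real thread data):
    #   {"content": "ToolResult(success=True, output=...)"}        -> True
    #   {"content": {"tool_execution": {...}}}                      -> True
    #   {"content": '{"tool_execution": {"function_name": ...}}'}   -> True
    #   {"content": "plain user text"}                              -> False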

    def compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
        """Compress the message content."""
        if isinstance(msg_content, str):
            if len(msg_content) > max_length:
                return msg_content[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
            else:
                return msg_content
        elif isinstance(msg_content, dict):
            if len(json.dumps(msg_content)) > max_length:
                # Special handling for edit_file tool results to preserve JSON structure
                tool_execution = msg_content.get("tool_execution", {})
                if tool_execution.get("function_name") == "edit_file":
                    output = tool_execution.get("result", {}).get("output", {})
                    if isinstance(output, dict):
                        # Truncate file contents within the JSON
                        for key in ["original_content", "updated_content"]:
                            if isinstance(output.get(key), str) and len(output[key]) > max_length // 4:
                                output[key] = output[key][:max_length // 4] + "\n... (truncated)"

                # After potential truncation, check the size again
                if len(json.dumps(msg_content)) > max_length:
                    # If still too large, fall back to string truncation
                    return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
                else:
                    return msg_content
            else:
                return msg_content
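
    # Illustrative behavior (hypothetical id and sizes): a 10_000-char string
    # with max_length=3000 comes back as its first 3000 characters followed by
    #   ... (truncated)
    #   message_id "abc-123"
    #   Use expand-message tool to see contents
    # so the agent can retrieve the full text later with the expand-message tool.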

    def safe_truncate(self, msg_content: Union[str, dict], max_length: int = 100000) -> Union[str, dict]:
        """Truncate the message content safely by removing the middle portion."""
        max_length = min(max_length, 100000)
        if isinstance(msg_content, str):
            if len(msg_content) > max_length:
                # Calculate how much to keep from the start and the end
                keep_length = max_length - 150  # Reserve space for the truncation message
                start_length = keep_length // 2
                end_length = keep_length - start_length

                start_part = msg_content[:start_length]
                end_part = msg_content[-end_length:] if end_length > 0 else ""

                return start_part + "\n\n... (middle truncated) ...\n\n" + end_part + "\n\nThis message is too long; repeat relevant information in your response to remember it"
            else:
                return msg_content
        elif isinstance(msg_content, dict):
            json_str = json.dumps(msg_content)
            if len(json_str) > max_length:
                # Calculate how much to keep from the start and the end
                keep_length = max_length - 150  # Reserve space for the truncation message
                start_length = keep_length // 2
                end_length = keep_length - start_length

                start_part = json_str[:start_length]
                end_part = json_str[-end_length:] if end_length > 0 else ""

                return start_part + "\n\n... (middle truncated) ...\n\n" + end_part + "\n\nThis message is too long; repeat relevant information in your response to remember it"
            else:
                return msg_content
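
    # Worked example: with the default max_length of 100_000, keep_length is
    # 99_850, so a 300_000-char string keeps its first 49_925 and last 49_925
    # characters and the middle ~200_150 characters are replaced by the notice.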

    def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
        """Compress the tool result messages except the most recent one."""
        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
        max_tokens_value = max_tokens or (100 * 1000)

        if uncompressed_total_token_count > max_tokens_value:
            _i = 0  # Number of ToolResult messages seen so far
            for msg in reversed(messages):  # Start from the end and work backwards
                if not isinstance(msg, dict):
                    continue  # Skip non-dict messages
                if self.is_tool_result_message(msg):  # Only compress ToolResult messages
                    _i += 1
                    msg_token_count = token_counter(model=llm_model, messages=[msg])  # Tokens in this message
                    if msg_token_count > token_threshold:  # The message is too long
                        if _i > 1:  # Not the most recent ToolResult message
                            message_id = msg.get('message_id')
                            if message_id:
                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
                            else:
                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                        else:
                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
        return messages

    def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
        """Compress the user messages except the most recent one."""
        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
        max_tokens_value = max_tokens or (100 * 1000)

        if uncompressed_total_token_count > max_tokens_value:
            _i = 0  # Number of user messages seen so far
            for msg in reversed(messages):  # Start from the end and work backwards
                if not isinstance(msg, dict):
                    continue  # Skip non-dict messages
                if msg.get('role') == 'user':  # Only compress user messages
                    _i += 1
                    msg_token_count = token_counter(model=llm_model, messages=[msg])  # Tokens in this message
                    if msg_token_count > token_threshold:  # The message is too long
                        if _i > 1:  # Not the most recent user message
                            message_id = msg.get('message_id')
                            if message_id:
                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
                            else:
                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                        else:
                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
        return messages

    def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
        """Compress the assistant messages except the most recent one."""
        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
        max_tokens_value = max_tokens or (100 * 1000)

        if uncompressed_total_token_count > max_tokens_value:
            _i = 0  # Number of assistant messages seen so far
            for msg in reversed(messages):  # Start from the end and work backwards
                if not isinstance(msg, dict):
                    continue  # Skip non-dict messages
                if msg.get('role') == 'assistant':  # Only compress assistant messages
                    _i += 1
                    msg_token_count = token_counter(model=llm_model, messages=[msg])  # Tokens in this message
                    if msg_token_count > token_threshold:  # The message is too long
                        if _i > 1:  # Not the most recent assistant message
                            message_id = msg.get('message_id')
                            if message_id:
                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
                            else:
                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                        else:
                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))

        return messages

    def remove_meta_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Remove meta fields (e.g. tool arguments) from the messages."""
        result: List[Dict[str, Any]] = []
        for msg in messages:
            msg_content = msg.get('content')
            # Try to parse msg_content as JSON if it's a string
            if isinstance(msg_content, str):
                try:
                    msg_content = json.loads(msg_content)
                except json.JSONDecodeError:
                    pass
            if isinstance(msg_content, dict):
                # Create a copy to avoid modifying the original
                msg_content_copy = msg_content.copy()
                if "tool_execution" in msg_content_copy:
                    tool_execution = msg_content_copy["tool_execution"].copy()
                    if "arguments" in tool_execution:
                        del tool_execution["arguments"]
                    msg_content_copy["tool_execution"] = tool_execution
                # Create a new message dict with the modified content
                new_msg = msg.copy()
                new_msg["content"] = json.dumps(msg_content_copy)
                result.append(new_msg)
            else:
                result.append(msg)
        return result
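
    # Illustrative example (hypothetical payload): a message whose content is
    #   {"tool_execution": {"function_name": "edit_file", "arguments": {...}, "result": {...}}}
    # is re-serialized without the bulky "arguments" key:
    #   {"tool_execution": {"function_name": "edit_file", "result": {...}}}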

    def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
        """Compress the messages.

        Args:
            messages: List of messages to compress
            llm_model: Model name for token counting
            max_tokens: Maximum allowed tokens
            token_threshold: Token threshold for individual message compression (must be a power of 2)
            max_iterations: Maximum number of compression iterations
        """
        # Set model-specific token limits
        if 'sonnet' in llm_model.lower():
            max_tokens = 200 * 1000 - 64000 - 28000
        elif 'gpt' in llm_model.lower():
            max_tokens = 128 * 1000 - 28000
        elif 'gemini' in llm_model.lower():
            max_tokens = 1000 * 1000 - 300000
        elif 'deepseek' in llm_model.lower():
            max_tokens = 128 * 1000 - 28000
        else:
            max_tokens = 41 * 1000 - 10000

        result = messages
        result = self.remove_meta_messages(result)

        uncompressed_total_token_count = token_counter(model=llm_model, messages=result)

        result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
        result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
        result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)

        compressed_token_count = token_counter(model=llm_model, messages=result)

        logger.info(f"compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}")  # Log the token compression for debugging later

        if max_iterations <= 0:
            logger.warning("compress_messages: Max iterations reached, omitting messages")
            result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
            return result

        if compressed_token_count > max_tokens:
            logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
            result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)

        return self.middle_out_messages(result)
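
    # With the defaults, each recursive pass above halves token_threshold
    # (4096 -> 2048 -> 1024 -> 512 -> 256), compressing individual messages
    # progressively harder; once max_iterations is exhausted, the method falls
    # back to dropping whole messages via compress_messages_by_omitting_messages.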

    def compress_messages_by_omitting_messages(
        self,
        messages: List[Dict[str, Any]],
        llm_model: str,
        max_tokens: Optional[int] = 41000,
        removal_batch_size: int = 10,
        min_messages_to_keep: int = 10
    ) -> List[Dict[str, Any]]:
        """Compress the messages by omitting messages from the middle.

        Args:
            messages: List of messages to compress
            llm_model: Model name for token counting
            max_tokens: Maximum allowed tokens
            removal_batch_size: Number of messages to remove per iteration
            min_messages_to_keep: Minimum number of messages to preserve
        """
        if not messages:
            return messages

        result = messages
        result = self.remove_meta_messages(result)

        # Early exit if no compression is needed
        initial_token_count = token_counter(model=llm_model, messages=result)
        max_allowed_tokens = max_tokens or (100 * 1000)

        if initial_token_count <= max_allowed_tokens:
            return result

        # Separate the system message (assumed to be first) from the conversation messages
        system_message = result[0] if result and isinstance(result[0], dict) and result[0].get('role') == 'system' else None
        conversation_messages = result[1:] if system_message else result

        safety_limit = 500
        current_token_count = initial_token_count

        while current_token_count > max_allowed_tokens and safety_limit > 0:
            safety_limit -= 1

            if len(conversation_messages) <= min_messages_to_keep:
                logger.warning(f"Cannot compress further: only {len(conversation_messages)} messages remain (min: {min_messages_to_keep})")
                break

            # Calculate the removal strategy based on the current message count
            if len(conversation_messages) > (removal_batch_size * 2):
                # Remove from the middle, keeping recent and early context
                middle_start = len(conversation_messages) // 2 - (removal_batch_size // 2)
                middle_end = middle_start + removal_batch_size
                conversation_messages = conversation_messages[:middle_start] + conversation_messages[middle_end:]
            else:
                # Remove from earlier messages, preserving recent context
                messages_to_remove = min(removal_batch_size, len(conversation_messages) // 2)
                if messages_to_remove > 0:
                    conversation_messages = conversation_messages[messages_to_remove:]
                else:
                    # Cannot remove any more messages
                    break

            # Recalculate the token count
            messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
            current_token_count = token_counter(model=llm_model, messages=messages_to_count)

        # Prepare the final result
        final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
        final_token_count = token_counter(model=llm_model, messages=final_messages)

        logger.info(f"compress_messages_by_omitting_messages: {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")

        return final_messages

    def middle_out_messages(self, messages: List[Dict[str, Any]], max_messages: int = 320) -> List[Dict[str, Any]]:
        """Remove messages from the middle of the list, keeping max_messages total."""
        if len(messages) <= max_messages:
            return messages

        # Keep half from the beginning and half from the end
        keep_start = max_messages // 2
        keep_end = max_messages - keep_start

        return messages[:keep_start] + messages[-keep_end:]
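

# Minimal usage sketch (illustrative, not part of the module's public surface):
# run compression over a synthetic thread before sending it to the LLM. Assumes
# DBConnection can be constructed in this environment; the model name and
# message ids are placeholders.
if __name__ == "__main__":
    manager = ContextManager()
    thread = [
        {"role": "system", "content": "You are a helpful agent.", "message_id": "m0"},
        {"role": "user", "content": "Summarize this repository. " * 10000, "message_id": "m1"},
        {"role": "assistant", "content": "Working on it.", "message_id": "m2"},
    ]
    compressed = manager.compress_messages(thread, llm_model="gpt-4o")
    print(token_counter(model="gpt-4o", messages=compressed))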