mirror of https://github.com/kortix-ai/suna.git
Merge pull request #661 from tnfssc/feat/context-summary-compressions
This commit is contained in:
commit
57b68664c0
|
@ -198,3 +198,5 @@ rabbitmq_data
|
|||
.setup_progress
|
||||
|
||||
.setup_env.json
|
||||
|
||||
backend/.test_token_compression.py
|
||||
|
|
|
@ -25,6 +25,7 @@ from utils.logger import logger
|
|||
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
|
||||
from services.langfuse import langfuse
|
||||
import datetime
|
||||
from litellm import token_counter
|
||||
|
||||
# Type alias for tool choice
|
||||
ToolChoice = Literal["auto", "required", "none"]
|
||||
|
@ -74,6 +75,134 @@ class ThreadManager:
|
|||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return False
|
||||
|
||||
def _compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
|
||||
"""Compress the message content."""
|
||||
# print("max_length", max_length)
|
||||
if isinstance(msg_content, str):
|
||||
if len(msg_content) > max_length:
|
||||
return msg_content[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
|
||||
else:
|
||||
return msg_content
|
||||
elif isinstance(msg_content, dict):
|
||||
if len(json.dumps(msg_content)) > max_length:
|
||||
return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
|
||||
else:
|
||||
return msg_content
|
||||
|
||||
def _safe_truncate(self, msg_content: Union[str, dict], max_length: int = 200000) -> Union[str, dict]:
|
||||
"""Truncate the message content safely."""
|
||||
if isinstance(msg_content, str):
|
||||
if len(msg_content) > max_length:
|
||||
return msg_content[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
|
||||
else:
|
||||
return msg_content
|
||||
elif isinstance(msg_content, dict):
|
||||
if len(json.dumps(msg_content)) > max_length:
|
||||
return json.dumps(msg_content)[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
|
||||
else:
|
||||
return msg_content
|
||||
|
||||
def _compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
|
||||
"""Compress the tool result messages except the most recent one."""
|
||||
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
|
||||
|
||||
if uncompressed_total_token_count > (max_tokens or (64 * 1000)):
|
||||
_i = 0 # Count the number of ToolResult messages
|
||||
for msg in reversed(messages): # Start from the end and work backwards
|
||||
if self._is_tool_result_message(msg): # Only compress ToolResult messages
|
||||
_i += 1 # Count the number of ToolResult messages
|
||||
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
|
||||
if msg_token_count > token_threshold: # If the message is too long
|
||||
if _i > 1: # If this is not the most recent ToolResult message
|
||||
message_id = msg.get('message_id') # Get the message_id
|
||||
if message_id:
|
||||
msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
|
||||
else:
|
||||
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
|
||||
else:
|
||||
msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
|
||||
return messages
|
||||
|
||||
def _compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
|
||||
"""Compress the user messages except the most recent one."""
|
||||
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
|
||||
|
||||
if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
|
||||
_i = 0 # Count the number of User messages
|
||||
for msg in reversed(messages): # Start from the end and work backwards
|
||||
if msg.get('role') == 'user': # Only compress User messages
|
||||
_i += 1 # Count the number of User messages
|
||||
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
|
||||
if msg_token_count > token_threshold: # If the message is too long
|
||||
if _i > 1: # If this is not the most recent User message
|
||||
message_id = msg.get('message_id') # Get the message_id
|
||||
if message_id:
|
||||
msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
|
||||
else:
|
||||
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
|
||||
else:
|
||||
msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
|
||||
return messages
|
||||
|
||||
def _compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
|
||||
"""Compress the assistant messages except the most recent one."""
|
||||
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
|
||||
if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
|
||||
_i = 0 # Count the number of Assistant messages
|
||||
for msg in reversed(messages): # Start from the end and work backwards
|
||||
if msg.get('role') == 'assistant': # Only compress Assistant messages
|
||||
_i += 1 # Count the number of Assistant messages
|
||||
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
|
||||
if msg_token_count > token_threshold: # If the message is too long
|
||||
if _i > 1: # If this is not the most recent Assistant message
|
||||
message_id = msg.get('message_id') # Get the message_id
|
||||
if message_id:
|
||||
msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
|
||||
else:
|
||||
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
|
||||
else:
|
||||
msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
|
||||
|
||||
return messages
|
||||
|
||||
def _compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: Optional[int] = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
|
||||
"""Compress the messages.
|
||||
token_threshold: must be a power of 2
|
||||
"""
|
||||
|
||||
if 'sonnet' in llm_model.lower():
|
||||
max_tokens = 200 * 1000 - 64000
|
||||
elif 'gpt' in llm_model.lower():
|
||||
max_tokens = 128 * 1000 - 28000
|
||||
elif 'gemini' in llm_model.lower():
|
||||
max_tokens = 1000 * 1000 - 300000
|
||||
elif 'deepseek' in llm_model.lower():
|
||||
max_tokens = 163 * 1000 - 32000
|
||||
else:
|
||||
max_tokens = 41 * 1000 - 10000
|
||||
|
||||
if max_iterations <= 0:
|
||||
logger.warning(f"_compress_messages: Max iterations reached, returning uncompressed messages")
|
||||
return messages
|
||||
|
||||
result = messages
|
||||
|
||||
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
|
||||
|
||||
result = self._compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
|
||||
result = self._compress_user_messages(result, llm_model, max_tokens, token_threshold)
|
||||
result = self._compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
|
||||
|
||||
compressed_token_count = token_counter(model=llm_model, messages=result)
|
||||
|
||||
logger.info(f"_compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}") # Log the token compression for debugging later
|
||||
|
||||
if (compressed_token_count > max_tokens):
|
||||
logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
|
||||
result = self._compress_messages(messages, llm_model, max_tokens, int(token_threshold / 2), max_iterations - 1)
|
||||
|
||||
return result
|
||||
|
||||
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
|
||||
"""Add a tool to the ThreadManager."""
|
||||
|
@ -287,7 +416,6 @@ Here are the XML tools available with examples:
|
|||
# 2. Check token count before proceeding
|
||||
token_count = 0
|
||||
try:
|
||||
from litellm import token_counter
|
||||
# Use the potentially modified working_system_prompt for token counting
|
||||
token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
|
||||
token_threshold = self.context_manager.token_threshold
|
||||
|
@ -344,25 +472,7 @@ Here are the XML tools available with examples:
|
|||
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
|
||||
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
|
||||
|
||||
|
||||
uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
|
||||
|
||||
if uncompressed_total_token_count > (llm_max_tokens or (100 * 1000)):
|
||||
_i = 0 # Count the number of ToolResult messages
|
||||
for msg in reversed(prepared_messages): # Start from the end and work backwards
|
||||
if self._is_tool_result_message(msg): # Only compress ToolResult messages
|
||||
_i += 1 # Count the number of ToolResult messages
|
||||
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
|
||||
if msg_token_count > 1000: # If the message is too long
|
||||
if _i > 1: # If this is not the most recent ToolResult message
|
||||
message_id = msg.get('message_id') # Get the message_id
|
||||
if message_id:
|
||||
msg["content"] = msg["content"][:3000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message" # Truncate the message
|
||||
else:
|
||||
msg["content"] = msg["content"][:200000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it" # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise
|
||||
|
||||
compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
|
||||
logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}") # Log the token compression for debugging later
|
||||
prepared_messages = self._compress_messages(prepared_messages, llm_model)
|
||||
|
||||
# 5. Make LLM API call
|
||||
logger.debug("Making LLM API call")
|
||||
|
|
Loading…
Reference in New Issue