Merge pull request #661 from tnfssc/feat/context-summary-compressions

This commit is contained in:
Sharath 2025-06-06 14:59:10 +05:30 committed by GitHub
commit 57b68664c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 132 additions and 20 deletions

2
.gitignore vendored
View File

@ -198,3 +198,5 @@ rabbitmq_data
.setup_progress
.setup_env.json
backend/.test_token_compression.py

View File

@ -25,6 +25,7 @@ from utils.logger import logger
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from services.langfuse import langfuse
import datetime
from litellm import token_counter
# Type alias for tool choice
ToolChoice = Literal["auto", "required", "none"]
@ -75,6 +76,134 @@ class ThreadManager:
pass
return False
def _compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
"""Compress the message content."""
# print("max_length", max_length)
if isinstance(msg_content, str):
if len(msg_content) > max_length:
return msg_content[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
else:
return msg_content
elif isinstance(msg_content, dict):
if len(json.dumps(msg_content)) > max_length:
return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
else:
return msg_content
def _safe_truncate(self, msg_content: Union[str, dict], max_length: int = 200000) -> Union[str, dict]:
"""Truncate the message content safely."""
if isinstance(msg_content, str):
if len(msg_content) > max_length:
return msg_content[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
else:
return msg_content
elif isinstance(msg_content, dict):
if len(json.dumps(msg_content)) > max_length:
return json.dumps(msg_content)[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
else:
return msg_content
def _compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
if uncompressed_total_token_count > (max_tokens or (64 * 1000)):
_i = 0 # Count the number of ToolResult messages
for msg in reversed(messages): # Start from the end and work backwards
if self._is_tool_result_message(msg): # Only compress ToolResult messages
_i += 1 # Count the number of ToolResult messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > token_threshold: # If the message is too long
if _i > 1: # If this is not the most recent ToolResult message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
else:
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
else:
msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
return messages
def _compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
_i = 0 # Count the number of User messages
for msg in reversed(messages): # Start from the end and work backwards
if msg.get('role') == 'user': # Only compress User messages
_i += 1 # Count the number of User messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > token_threshold: # If the message is too long
if _i > 1: # If this is not the most recent User message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
else:
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
else:
msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
return messages
def _compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
_i = 0 # Count the number of Assistant messages
for msg in reversed(messages): # Start from the end and work backwards
if msg.get('role') == 'assistant': # Only compress Assistant messages
_i += 1 # Count the number of Assistant messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > token_threshold: # If the message is too long
if _i > 1: # If this is not the most recent Assistant message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
else:
logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
else:
msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
return messages
def _compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: Optional[int] = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
"""Compress the messages.
token_threshold: must be a power of 2
"""
if 'sonnet' in llm_model.lower():
max_tokens = 200 * 1000 - 64000
elif 'gpt' in llm_model.lower():
max_tokens = 128 * 1000 - 28000
elif 'gemini' in llm_model.lower():
max_tokens = 1000 * 1000 - 300000
elif 'deepseek' in llm_model.lower():
max_tokens = 163 * 1000 - 32000
else:
max_tokens = 41 * 1000 - 10000
if max_iterations <= 0:
logger.warning(f"_compress_messages: Max iterations reached, returning uncompressed messages")
return messages
result = messages
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
result = self._compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
result = self._compress_user_messages(result, llm_model, max_tokens, token_threshold)
result = self._compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
compressed_token_count = token_counter(model=llm_model, messages=result)
logger.info(f"_compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}") # Log the token compression for debugging later
if (compressed_token_count > max_tokens):
logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
result = self._compress_messages(messages, llm_model, max_tokens, int(token_threshold / 2), max_iterations - 1)
return result
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
"""Add a tool to the ThreadManager."""
self.tool_registry.register_tool(tool_class, function_names, **kwargs)
@ -287,7 +416,6 @@ Here are the XML tools available with examples:
# 2. Check token count before proceeding
token_count = 0
try:
from litellm import token_counter
# Use the potentially modified working_system_prompt for token counting
token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
token_threshold = self.context_manager.token_threshold
@ -344,25 +472,7 @@ Here are the XML tools available with examples:
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
if uncompressed_total_token_count > (llm_max_tokens or (100 * 1000)):
_i = 0 # Count the number of ToolResult messages
for msg in reversed(prepared_messages): # Start from the end and work backwards
if self._is_tool_result_message(msg): # Only compress ToolResult messages
_i += 1 # Count the number of ToolResult messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > 1000: # If the message is too long
if _i > 1: # If this is not the most recent ToolResult message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = msg["content"][:3000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message" # Truncate the message
else:
msg["content"] = msg["content"][:200000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it" # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise
compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}") # Log the token compression for debugging later
prepared_messages = self._compress_messages(prepared_messages, llm_model)
# 5. Make LLM API call
logger.debug("Making LLM API call")