mirror of https://github.com/kortix-ai/suna.git
Merge pull request #770 from tnfssc/fix/context-window-finale
commit 7266113e9d
@@ -240,7 +240,8 @@ class ThreadManager:
         logger.info(f"_compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}") # Log the token compression for debugging later
 
         if max_iterations <= 0:
-            logger.warning(f"_compress_messages: Max iterations reached")
+            logger.warning(f"_compress_messages: Max iterations reached, omitting messages")
+            result = self._compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
             return result
 
         if (compressed_token_count > max_tokens):
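Taken with the next hunk, this wires a fallback into the recursive compressor: once max_iterations is exhausted, whole messages are dropped via _compress_messages_by_omitting_messages instead of returning an oversized context. A minimal runnable sketch of that control flow, under illustrative assumptions (strings stand in for messages, character counts for tokens; all names here are hypothetical, not the ThreadManager API):

def compress(messages, max_tokens, token_threshold, max_iterations):
    # Out of iterations: stop shrinking message bodies and omit whole messages.
    if max_iterations <= 0:
        return omit_messages(messages, max_tokens)
    # Crude per-message truncation stands in for the real compression pass.
    compressed = [m[:token_threshold] for m in messages]
    if sum(len(m) for m in compressed) > max_tokens:
        # Halve the threshold and recurse, mirroring int(token_threshold / 2) below.
        return compress(compressed, max_tokens, token_threshold // 2, max_iterations - 1)
    return compressed

def omit_messages(messages, max_tokens):
    # Simplest possible omission: drop the oldest message until the rest fit.
    while messages and sum(len(m) for m in messages) > max_tokens:
        messages = messages[1:]
    return messages

print(len(compress(["x" * 500] * 20, max_tokens=1000, token_threshold=400, max_iterations=3)))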
@@ -248,6 +249,78 @@ class ThreadManager:
             result = self._compress_messages(messages, llm_model, max_tokens, int(token_threshold / 2), max_iterations - 1)
 
         return result
 
+    def _compress_messages_by_omitting_messages(
+            self,
+            messages: List[Dict[str, Any]],
+            llm_model: str,
+            max_tokens: Optional[int] = 41000,
+            removal_batch_size: int = 10,
+            min_messages_to_keep: int = 10
+        ) -> List[Dict[str, Any]]:
+        """Compress the messages by omitting messages from the middle.
+
+        Args:
+            messages: List of messages to compress
+            llm_model: Model name for token counting
+            max_tokens: Maximum allowed tokens
+            removal_batch_size: Number of messages to remove per iteration
+            min_messages_to_keep: Minimum number of messages to preserve
+        """
+        if not messages:
+            return messages
+
+        result = messages
+        result = self._remove_meta_messages(result)
+
+        # Early exit if no compression needed
+        initial_token_count = token_counter(model=llm_model, messages=result)
+        max_allowed_tokens = max_tokens or (100 * 1000)
+
+        if initial_token_count <= max_allowed_tokens:
+            return result
+
+        # Separate system message (assumed to be first) from conversation messages
+        system_message = messages[0] if messages and messages[0].get('role') == 'system' else None
+        conversation_messages = result[1:] if system_message else result
+
+        safety_limit = 500
+        current_token_count = initial_token_count
+
+        while current_token_count > max_allowed_tokens and safety_limit > 0:
+            safety_limit -= 1
+
+            if len(conversation_messages) <= min_messages_to_keep:
+                logger.warning(f"Cannot compress further: only {len(conversation_messages)} messages remain (min: {min_messages_to_keep})")
+                break
+
+            # Calculate removal strategy based on current message count
+            if len(conversation_messages) > (removal_batch_size * 2):
+                # Remove from middle, keeping recent and early context
+                middle_start = len(conversation_messages) // 2 - (removal_batch_size // 2)
+                middle_end = middle_start + removal_batch_size
+                conversation_messages = conversation_messages[:middle_start] + conversation_messages[middle_end:]
+            else:
+                # Remove from earlier messages, preserving recent context
+                messages_to_remove = min(removal_batch_size, len(conversation_messages) // 2)
+                if messages_to_remove > 0:
+                    conversation_messages = conversation_messages[messages_to_remove:]
+                else:
+                    # Can't remove any more messages
+                    break
+
+            # Recalculate token count
+            messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
+            current_token_count = token_counter(model=llm_model, messages=messages_to_count)
+
+        # Prepare final result
+        final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
+        final_token_count = token_counter(model=llm_model, messages=final_messages)
+
+        logger.info(f"_compress_messages_by_omitting_messages: {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")
+
+        return final_messages
+
+
     def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
         """Add a tool to the ThreadManager."""