compress and omit if exceeds context window

This commit is contained in:
Krishav Raj Singh 2025-10-02 15:05:59 +05:30
parent 7ff206157d
commit ca94a759a9
3 changed files with 101 additions and 38 deletions

View File

@ -88,9 +88,14 @@ class ContextManager:
else:
return msg_content
def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one.
CRITICAL: Never compresses cached messages to preserve cache hits.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
@ -112,9 +117,14 @@ class ContextManager:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one.
CRITICAL: Never compresses cached messages to preserve cache hits.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
@ -136,9 +146,14 @@ class ContextManager:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one.
CRITICAL: Never compresses cached messages to preserve cache hits.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
@ -188,15 +203,10 @@ class ContextManager:
result.append(msg)
return result
def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
"""Compress the messages.
def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5, actual_total_tokens: Optional[int] = None, system_prompt: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""Compress the messages WITHOUT applying caching during iterations.
Args:
messages: List of messages to compress
llm_model: Model name for token counting
max_tokens: Maximum allowed tokens
token_threshold: Token threshold for individual message compression (must be a power of 2)
max_iterations: Maximum number of compression iterations
Caching should be applied ONCE at the end by the caller, not during compression.
"""
# Get model-specific token limits from constants
context_window = model_manager.get_context_window(llm_model)
@ -218,24 +228,46 @@ class ContextManager:
result = messages
result = self.remove_meta_messages(result)
uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
# Calculate initial token count - just conversation + system prompt, NO caching overhead
print(f"actual_total_tokens: {actual_total_tokens}")
if actual_total_tokens is not None:
uncompressed_total_token_count = actual_total_tokens
else:
print("no actual_total_tokens")
# Count conversation + system prompt WITHOUT caching
if system_prompt:
uncompressed_total_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
logger.info(f"Initial token count (no caching): {uncompressed_total_token_count}")
result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
# Apply compression
result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
compressed_token_count = token_counter(model=llm_model, messages=result)
# Recalculate WITHOUT caching overhead
if system_prompt:
compressed_total = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
compressed_total = token_counter(model=llm_model, messages=result)
logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_token_count} tokens")
logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_total} token")
# Recurse if still too large
if max_iterations <= 0:
logger.warning(f"compress_messages: Max iterations reached, omitting messages")
result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
logger.warning(f"Max iterations reached, omitting messages")
result = self.compress_messages_by_omitting_messages(result, llm_model, max_tokens, system_prompt=system_prompt)
return result
if compressed_token_count > max_tokens:
logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)
if compressed_total > max_tokens:
logger.warning(f"Further compression needed: {compressed_total} > {max_tokens}")
# Recursive call - still NO caching
result = self.compress_messages(
result, llm_model, max_tokens,
token_threshold // 2, max_iterations - 1,
compressed_total, system_prompt,
)
return self.middle_out_messages(result)
@ -245,7 +277,8 @@ class ContextManager:
llm_model: str,
max_tokens: Optional[int] = 41000,
removal_batch_size: int = 10,
min_messages_to_keep: int = 10
min_messages_to_keep: int = 10,
system_prompt: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""Compress the messages by omitting messages from the middle.
@ -263,15 +296,19 @@ class ContextManager:
result = self.remove_meta_messages(result)
# Early exit if no compression needed
initial_token_count = token_counter(model=llm_model, messages=result)
if system_prompt:
initial_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
initial_token_count = token_counter(model=llm_model, messages=result)
max_allowed_tokens = max_tokens or (100 * 1000)
if initial_token_count <= max_allowed_tokens:
return result
# Separate system message (assumed to be first) from conversation messages
system_message = messages[0] if messages and isinstance(messages[0], dict) and messages[0].get('role') == 'system' else None
conversation_messages = result[1:] if system_message else result
system_message = system_prompt
conversation_messages = result
safety_limit = 500
current_token_count = initial_token_count
@ -302,9 +339,14 @@ class ContextManager:
messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
current_token_count = token_counter(model=llm_model, messages=messages_to_count)
# Prepare final result
final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
final_token_count = token_counter(model=llm_model, messages=final_messages)
# Prepare final result - return only conversation messages (matches compress_messages pattern)
final_messages = conversation_messages
# Log with system prompt included for accurate token reporting
if system_message:
final_token_count = token_counter(model=llm_model, messages=[system_message] + final_messages)
else:
final_token_count = token_counter(model=llm_model, messages=final_messages)
logger.info(f"Context compression (omit): {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")

View File

@ -232,6 +232,12 @@ def apply_anthropic_caching_strategy(
This prevents cache invalidation while optimizing for context window utilization
and cost efficiency across different conversation patterns.
"""
# DEBUG: Count message roles to verify tool results are included
message_roles = [msg.get('role', 'unknown') for msg in conversation_messages]
role_counts = {}
for role in message_roles:
role_counts[role] = role_counts.get(role, 0) + 1
logger.debug(f"🔍 CACHING INPUT: {len(conversation_messages)} messages - Roles: {role_counts}")
if not conversation_messages:
conversation_messages = []
@ -256,10 +262,15 @@ def apply_anthropic_caching_strategy(
# Calculate mathematically optimized cache threshold
if cache_threshold_tokens is None:
# Include system prompt tokens in calculation for accurate density (like compression does)
# Use token_counter on combined messages to match compression's calculation method
from litellm import token_counter
total_tokens = token_counter(model=model_name, messages=[working_system_prompt] + conversation_messages) if conversation_messages else 0
cache_threshold_tokens = calculate_optimal_cache_threshold(
context_window_tokens,
len(conversation_messages),
get_messages_token_count(conversation_messages, model_name) if conversation_messages else 0
total_tokens # Now includes system prompt for accurate density calculation
)
logger.info(f"📊 Applying single cache breakpoint strategy for {len(conversation_messages)} messages")
@ -307,6 +318,7 @@ def apply_anthropic_caching_strategy(
max_cacheable_tokens = int(context_window_tokens * 0.8)
if total_conversation_tokens <= max_cacheable_tokens:
logger.debug(f"Conversation fits within cache limits - use chunked approach")
# Conversation fits within cache limits - use chunked approach
chunks_created = create_conversation_chunks(
conversation_messages,
@ -350,6 +362,7 @@ def create_conversation_chunks(
Final messages are NEVER cached to prevent cache invalidation.
Returns number of cache blocks created.
"""
logger.debug(f"Creating conversation chunks - chunk threshold: {chunk_threshold_tokens}, max blocks: {max_blocks}")
if not messages or max_blocks <= 0:
return 0

View File

@ -17,6 +17,7 @@ from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from core.services.langfuse import langfuse
from datetime import datetime, timezone
from core.billing.billing_integration import billing_integration
from litellm.utils import token_counter
ToolChoice = Literal["auto", "required", "none"]
@ -305,8 +306,11 @@ class ThreadManager:
if ENABLE_CONTEXT_MANAGER:
logger.debug(f"Context manager enabled, compressing {len(messages)} messages")
context_manager = ContextManager()
compressed_messages = context_manager.compress_messages(
messages, llm_model, max_tokens=llm_max_tokens
messages, llm_model, max_tokens=llm_max_tokens,
actual_total_tokens=None, # Will be calculated inside
system_prompt=system_prompt # KEY FIX: No caching during compression
)
logger.debug(f"Context compression completed: {len(messages)} -> {len(compressed_messages)} messages")
messages = compressed_messages
@ -340,6 +344,10 @@ class ThreadManager:
except Exception as e:
logger.warning(f"Failed to update Langfuse generation: {e}")
# Log final prepared messages token count
final_prepared_tokens = token_counter(model=llm_model, messages=prepared_messages)
logger.info(f"📤 Final prepared messages being sent to LLM: {final_prepared_tokens} tokens")
# Make LLM call
try:
llm_response = await make_llm_api_call(