Merge pull request #1769 from KrishavRajSingh/fix/prompt_caching

compress and omit if exceeds context window
Krishav 2025-10-02 20:52:46 +05:30 committed by GitHub
commit d3190e42d8
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 104 additions and 38 deletions
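
The change works in the order the title describes: message contents are compressed by deterministic truncation first, and messages are omitted from the middle of the conversation only if the result still exceeds the context window; prompt caching is applied once at the end by the caller rather than during compression. A minimal, self-contained sketch of that flow (illustrative names and a rough 4-chars-per-token estimate, not the repository's actual API):

from typing import Any, Dict, List

def truncate_deterministically(text: str, max_chars: int) -> str:
    # Simple truncation: the same input always yields the same output,
    # which is what lets a later cache breakpoint hit on compressed content.
    return text if len(text) <= max_chars else text[:max_chars]

def fit_to_window(messages: List[Dict[str, Any]], max_tokens: int) -> List[Dict[str, Any]]:
    approx = lambda msgs: sum(len(str(m.get("content", ""))) // 4 for m in msgs)  # rough token estimate
    # Compress everything except the most recent message.
    compressed = [dict(m, content=truncate_deterministically(str(m.get("content", "")), 4000))
                  for m in messages[:-1]] + messages[-1:]
    if approx(compressed) <= max_tokens:
        return compressed
    # Still too large: omit messages from the middle, keeping both ends.
    if len(compressed) <= 2:
        return compressed
    keep = max(2, len(compressed) // 2)
    head, tail = keep // 2, keep - keep // 2
    return compressed[:head] + compressed[-tail:]

The real implementation counts tokens with litellm's token_counter and recurses with a halved per-message threshold instead of a fixed character cap, but the ordering is the same.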

View File

@@ -88,9 +88,15 @@ class ContextManager:
         else:
             return msg_content
-    def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
-        """Compress the tool result messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+    def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Compress the tool result messages except the most recent one.
+        Compression is deterministic (simple truncation), ensuring consistent results across requests.
+        This allows prompt caching (applied later) to produce cache hits on identical compressed content.
+        """
+        if uncompressed_total_token_count is None:
+            uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
         max_tokens_value = max_tokens or (100 * 1000)
         if uncompressed_total_token_count > max_tokens_value:
@@ -112,9 +118,15 @@ class ContextManager:
                         msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
         return messages
-    def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
-        """Compress the user messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+    def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Compress the user messages except the most recent one.
+        Compression is deterministic (simple truncation), ensuring consistent results across requests.
+        This allows prompt caching (applied later) to produce cache hits on identical compressed content.
+        """
+        if uncompressed_total_token_count is None:
+            uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
         max_tokens_value = max_tokens or (100 * 1000)
         if uncompressed_total_token_count > max_tokens_value:
@@ -136,9 +148,15 @@ class ContextManager:
                         msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
         return messages
-    def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
-        """Compress the assistant messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+    def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Compress the assistant messages except the most recent one.
+        Compression is deterministic (simple truncation), ensuring consistent results across requests.
+        This allows prompt caching (applied later) to produce cache hits on identical compressed content.
+        """
+        if uncompressed_total_token_count is None:
+            uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
         max_tokens_value = max_tokens or (100 * 1000)
         if uncompressed_total_token_count > max_tokens_value:
@@ -188,15 +206,10 @@ class ContextManager:
             result.append(msg)
         return result
-    def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
-        """Compress the messages.
+    def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5, actual_total_tokens: Optional[int] = None, system_prompt: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
+        """Compress the messages WITHOUT applying caching during iterations.
-        Args:
-            messages: List of messages to compress
-            llm_model: Model name for token counting
-            max_tokens: Maximum allowed tokens
-            token_threshold: Token threshold for individual message compression (must be a power of 2)
-            max_iterations: Maximum number of compression iterations
+        Caching should be applied ONCE at the end by the caller, not during compression.
         """
         # Get model-specific token limits from constants
         context_window = model_manager.get_context_window(llm_model)
@@ -218,24 +231,46 @@ class ContextManager:
         result = messages
         result = self.remove_meta_messages(result)
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
+        # Calculate initial token count - just conversation + system prompt, NO caching overhead
+        print(f"actual_total_tokens: {actual_total_tokens}")
+        if actual_total_tokens is not None:
+            uncompressed_total_token_count = actual_total_tokens
+        else:
+            print("no actual_total_tokens")
+            # Count conversation + system prompt WITHOUT caching
+            if system_prompt:
+                uncompressed_total_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
+            else:
+                uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
+        logger.info(f"Initial token count (no caching): {uncompressed_total_token_count}")
-        result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
-        result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
-        result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
+        # Apply compression
+        result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
+        result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
+        result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
-        compressed_token_count = token_counter(model=llm_model, messages=result)
-        logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_token_count} tokens")
+        # Recalculate WITHOUT caching overhead
+        if system_prompt:
+            compressed_total = token_counter(model=llm_model, messages=[system_prompt] + result)
+        else:
+            compressed_total = token_counter(model=llm_model, messages=result)
+        logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_total} token")
+        # Recurse if still too large
         if max_iterations <= 0:
-            logger.warning(f"compress_messages: Max iterations reached, omitting messages")
-            result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
+            logger.warning(f"Max iterations reached, omitting messages")
+            result = self.compress_messages_by_omitting_messages(result, llm_model, max_tokens, system_prompt=system_prompt)
             return result
-        if compressed_token_count > max_tokens:
-            logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
-            result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)
+        if compressed_total > max_tokens:
+            logger.warning(f"Further compression needed: {compressed_total} > {max_tokens}")
+            # Recursive call - still NO caching
+            result = self.compress_messages(
+                result, llm_model, max_tokens,
+                token_threshold // 2, max_iterations - 1,
+                compressed_total, system_prompt,
+            )
         return self.middle_out_messages(result)
@@ -245,7 +280,8 @@ class ContextManager:
         llm_model: str,
         max_tokens: Optional[int] = 41000,
         removal_batch_size: int = 10,
-        min_messages_to_keep: int = 10
+        min_messages_to_keep: int = 10,
+        system_prompt: Optional[Dict[str, Any]] = None
     ) -> List[Dict[str, Any]]:
         """Compress the messages by omitting messages from the middle.
@@ -263,15 +299,19 @@ class ContextManager:
         result = self.remove_meta_messages(result)
         # Early exit if no compression needed
-        initial_token_count = token_counter(model=llm_model, messages=result)
+        if system_prompt:
+            initial_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
+        else:
+            initial_token_count = token_counter(model=llm_model, messages=result)
         max_allowed_tokens = max_tokens or (100 * 1000)
         if initial_token_count <= max_allowed_tokens:
             return result
         # Separate system message (assumed to be first) from conversation messages
-        system_message = messages[0] if messages and isinstance(messages[0], dict) and messages[0].get('role') == 'system' else None
-        conversation_messages = result[1:] if system_message else result
+        system_message = system_prompt
+        conversation_messages = result
         safety_limit = 500
         current_token_count = initial_token_count
@@ -302,9 +342,14 @@ class ContextManager:
             messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
             current_token_count = token_counter(model=llm_model, messages=messages_to_count)
-        # Prepare final result
-        final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
-        final_token_count = token_counter(model=llm_model, messages=final_messages)
+        # Prepare final result - return only conversation messages (matches compress_messages pattern)
+        final_messages = conversation_messages
+        # Log with system prompt included for accurate token reporting
+        if system_message:
+            final_token_count = token_counter(model=llm_model, messages=[system_message] + final_messages)
+        else:
+            final_token_count = token_counter(model=llm_model, messages=final_messages)
         logger.info(f"Context compression (omit): {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")

View File

@@ -232,6 +232,12 @@ def apply_anthropic_caching_strategy(
     This prevents cache invalidation while optimizing for context window utilization
     and cost efficiency across different conversation patterns.
     """
+    # DEBUG: Count message roles to verify tool results are included
+    message_roles = [msg.get('role', 'unknown') for msg in conversation_messages]
+    role_counts = {}
+    for role in message_roles:
+        role_counts[role] = role_counts.get(role, 0) + 1
+    logger.debug(f"🔍 CACHING INPUT: {len(conversation_messages)} messages - Roles: {role_counts}")
     if not conversation_messages:
         conversation_messages = []
@@ -256,10 +262,15 @@
     # Calculate mathematically optimized cache threshold
     if cache_threshold_tokens is None:
+        # Include system prompt tokens in calculation for accurate density (like compression does)
+        # Use token_counter on combined messages to match compression's calculation method
+        from litellm import token_counter
+        total_tokens = token_counter(model=model_name, messages=[working_system_prompt] + conversation_messages) if conversation_messages else 0
         cache_threshold_tokens = calculate_optimal_cache_threshold(
             context_window_tokens,
             len(conversation_messages),
-            get_messages_token_count(conversation_messages, model_name) if conversation_messages else 0
+            total_tokens  # Now includes system prompt for accurate density calculation
         )
     logger.info(f"📊 Applying single cache breakpoint strategy for {len(conversation_messages)} messages")
@@ -307,6 +318,7 @@
     max_cacheable_tokens = int(context_window_tokens * 0.8)
     if total_conversation_tokens <= max_cacheable_tokens:
+        logger.debug(f"Conversation fits within cache limits - use chunked approach")
         # Conversation fits within cache limits - use chunked approach
         chunks_created = create_conversation_chunks(
             conversation_messages,
@@ -350,6 +362,7 @@ def create_conversation_chunks(
     Final messages are NEVER cached to prevent cache invalidation.
     Returns number of cache blocks created.
     """
+    logger.debug(f"Creating conversation chunks - chunk threshold: {chunk_threshold_tokens}, max blocks: {max_blocks}")
     if not messages or max_blocks <= 0:
         return 0
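
The threshold calculation above now counts the system prompt together with the conversation via litellm's token_counter, so the caching strategy and the compression path agree on the token total they are budgeting against. A hedged sketch of just that counting step (calculate_optimal_cache_threshold is the repository's helper and its internals are not shown in this diff; the wrapper name below is illustrative):

from litellm import token_counter  # same counter the compression path uses

def conversation_total_tokens(model_name: str, system_prompt: dict, conversation: list) -> int:
    # Count the system prompt and the conversation together, mirroring what
    # compress_messages now counts, so both sides see the same total.
    if not conversation:
        return 0
    return token_counter(model=model_name, messages=[system_prompt] + conversation)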

View File

@@ -17,6 +17,7 @@ from langfuse.client import StatefulGenerationClient, StatefulTraceClient
 from core.services.langfuse import langfuse
 from datetime import datetime, timezone
 from core.billing.billing_integration import billing_integration
+from litellm.utils import token_counter
 ToolChoice = Literal["auto", "required", "none"]
@@ -305,8 +306,11 @@ class ThreadManager:
         if ENABLE_CONTEXT_MANAGER:
             logger.debug(f"Context manager enabled, compressing {len(messages)} messages")
             context_manager = ContextManager()
             compressed_messages = context_manager.compress_messages(
-                messages, llm_model, max_tokens=llm_max_tokens
+                messages, llm_model, max_tokens=llm_max_tokens,
+                actual_total_tokens=None,  # Will be calculated inside
+                system_prompt=system_prompt  # KEY FIX: No caching during compression
             )
             logger.debug(f"Context compression completed: {len(messages)} -> {len(compressed_messages)} messages")
             messages = compressed_messages
@@ -340,6 +344,10 @@ class ThreadManager:
         except Exception as e:
             logger.warning(f"Failed to update Langfuse generation: {e}")
+        # Log final prepared messages token count
+        final_prepared_tokens = token_counter(model=llm_model, messages=prepared_messages)
+        logger.info(f"📤 Final prepared messages being sent to LLM: {final_prepared_tokens} tokens")
         # Make LLM call
         try:
             llm_response = await make_llm_api_call(
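
Taken together, the caller-side contract after this change is: compress with the system prompt passed alongside the conversation (it is counted for budgeting but never prepended to the returned list), then apply prompt caching exactly once to the final message list. A hedged sketch of that ordering; only compress_messages' keyword arguments appear in this diff, so the caching step is indicated by a comment rather than a call:

compressed_messages = context_manager.compress_messages(
    messages, llm_model, max_tokens=llm_max_tokens,
    actual_total_tokens=None,     # recomputed inside from conversation + system prompt
    system_prompt=system_prompt,  # counted for budgeting, never prepended to the result
)
# apply_anthropic_caching_strategy(...) then runs once on the compressed messages;
# its arguments are not shown in this diff, so they are omitted here.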