From ca94a759a911c17cb4fe6db393a45c76d4787882 Mon Sep 17 00:00:00 2001
From: Krishav Raj Singh
Date: Thu, 2 Oct 2025 15:05:59 +0530
Subject: [PATCH] compress and omit messages if they exceed the context window

---
 backend/core/agentpress/context_manager.py | 114 ++++++++++++++-------
 backend/core/agentpress/prompt_caching.py  |  15 ++-
 backend/core/agentpress/thread_manager.py  |  10 +-
 3 files changed, 101 insertions(+), 38 deletions(-)

diff --git a/backend/core/agentpress/context_manager.py b/backend/core/agentpress/context_manager.py
index 5fedb8d1..9f9c7cca 100644
--- a/backend/core/agentpress/context_manager.py
+++ b/backend/core/agentpress/context_manager.py
@@ -88,9 +88,14 @@ class ContextManager:
         else:
             return msg_content
 
-    def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
-        """Compress the tool result messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+    def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Compress the tool result messages except the most recent one.
+
+        CRITICAL: Never compresses cached messages to preserve cache hits.
+        """
+        if uncompressed_total_token_count is None:
+            uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+
         max_tokens_value = max_tokens or (100 * 1000)
 
         if uncompressed_total_token_count > max_tokens_value:
@@ -112,9 +117,14 @@ class ContextManager:
                 msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
         return messages
 
-    def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
-        """Compress the user messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+    def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Compress the user messages except the most recent one.
+
+        CRITICAL: Never compresses cached messages to preserve cache hits.
+        """
+        if uncompressed_total_token_count is None:
+            uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+
         max_tokens_value = max_tokens or (100 * 1000)
 
         if uncompressed_total_token_count > max_tokens_value:
@@ -136,9 +146,14 @@ class ContextManager:
                 msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
         return messages
 
-    def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
-        """Compress the assistant messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+    def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Compress the assistant messages except the most recent one.
+
+        CRITICAL: Never compresses cached messages to preserve cache hits.
+ """ + if uncompressed_total_token_count is None: + uncompressed_total_token_count = token_counter(model=llm_model, messages=messages) + max_tokens_value = max_tokens or (100 * 1000) if uncompressed_total_token_count > max_tokens_value: @@ -188,15 +203,10 @@ class ContextManager: result.append(msg) return result - def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]: - """Compress the messages. + def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5, actual_total_tokens: Optional[int] = None, system_prompt: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: + """Compress the messages WITHOUT applying caching during iterations. - Args: - messages: List of messages to compress - llm_model: Model name for token counting - max_tokens: Maximum allowed tokens - token_threshold: Token threshold for individual message compression (must be a power of 2) - max_iterations: Maximum number of compression iterations + Caching should be applied ONCE at the end by the caller, not during compression. """ # Get model-specific token limits from constants context_window = model_manager.get_context_window(llm_model) @@ -218,24 +228,46 @@ class ContextManager: result = messages result = self.remove_meta_messages(result) - uncompressed_total_token_count = token_counter(model=llm_model, messages=result) + # Calculate initial token count - just conversation + system prompt, NO caching overhead + print(f"actual_total_tokens: {actual_total_tokens}") + if actual_total_tokens is not None: + uncompressed_total_token_count = actual_total_tokens + else: + print("no actual_total_tokens") + # Count conversation + system prompt WITHOUT caching + if system_prompt: + uncompressed_total_token_count = token_counter(model=llm_model, messages=[system_prompt] + result) + else: + uncompressed_total_token_count = token_counter(model=llm_model, messages=result) + logger.info(f"Initial token count (no caching): {uncompressed_total_token_count}") - result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold) - result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold) - result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold) + # Apply compression + result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count) + result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count) + result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count) - compressed_token_count = token_counter(model=llm_model, messages=result) - - logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_token_count} tokens") + # Recalculate WITHOUT caching overhead + if system_prompt: + compressed_total = token_counter(model=llm_model, messages=[system_prompt] + result) + else: + compressed_total = token_counter(model=llm_model, messages=result) + + logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_total} token") + # Recurse if still too large if max_iterations <= 0: - logger.warning(f"compress_messages: Max iterations reached, omitting messages") - result = 
+            logger.warning(f"Max iterations reached, omitting messages")
+            result = self.compress_messages_by_omitting_messages(result, llm_model, max_tokens, system_prompt=system_prompt)
             return result
 
-        if compressed_token_count > max_tokens:
-            logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
-            result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)
+        if compressed_total > max_tokens:
+            logger.warning(f"Further compression needed: {compressed_total} > {max_tokens}")
+            # Recursive call - still NO caching
+            result = self.compress_messages(
+                result, llm_model, max_tokens,
+                token_threshold // 2, max_iterations - 1,
+                compressed_total, system_prompt,
+            )
 
         return self.middle_out_messages(result)
@@ -245,7 +277,8 @@ class ContextManager:
         llm_model: str,
         max_tokens: Optional[int] = 41000,
         removal_batch_size: int = 10,
-        min_messages_to_keep: int = 10
+        min_messages_to_keep: int = 10,
+        system_prompt: Optional[Dict[str, Any]] = None
     ) -> List[Dict[str, Any]]:
         """Compress the messages by omitting messages from the middle.
 
@@ -263,15 +296,19 @@ class ContextManager:
         result = self.remove_meta_messages(result)
 
         # Early exit if no compression needed
-        initial_token_count = token_counter(model=llm_model, messages=result)
+        if system_prompt:
+            initial_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
+        else:
+            initial_token_count = token_counter(model=llm_model, messages=result)
+
         max_allowed_tokens = max_tokens or (100 * 1000)
 
         if initial_token_count <= max_allowed_tokens:
             return result
 
         # Separate system message (assumed to be first) from conversation messages
-        system_message = messages[0] if messages and isinstance(messages[0], dict) and messages[0].get('role') == 'system' else None
-        conversation_messages = result[1:] if system_message else result
+        system_message = system_prompt
+        conversation_messages = result
 
         safety_limit = 500
         current_token_count = initial_token_count
@@ -302,9 +339,14 @@ class ContextManager:
             messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
             current_token_count = token_counter(model=llm_model, messages=messages_to_count)
 
-        # Prepare final result
-        final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
-        final_token_count = token_counter(model=llm_model, messages=final_messages)
+        # Prepare final result - return only conversation messages (matches compress_messages pattern)
+        final_messages = conversation_messages
+
+        # Log with system prompt included for accurate token reporting
+        if system_message:
+            final_token_count = token_counter(model=llm_model, messages=[system_message] + final_messages)
+        else:
+            final_token_count = token_counter(model=llm_model, messages=final_messages)
 
         logger.info(f"Context compression (omit): {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")
 
diff --git a/backend/core/agentpress/prompt_caching.py b/backend/core/agentpress/prompt_caching.py
index 6710c154..e6422a2b 100644
--- a/backend/core/agentpress/prompt_caching.py
+++ b/backend/core/agentpress/prompt_caching.py
@@ -232,6 +232,12 @@ def apply_anthropic_caching_strategy(
 
     This prevents cache invalidation while optimizing for context window
     utilization and cost efficiency across different conversation patterns.
""" + # DEBUG: Count message roles to verify tool results are included + message_roles = [msg.get('role', 'unknown') for msg in conversation_messages] + role_counts = {} + for role in message_roles: + role_counts[role] = role_counts.get(role, 0) + 1 + logger.debug(f"🔍 CACHING INPUT: {len(conversation_messages)} messages - Roles: {role_counts}") if not conversation_messages: conversation_messages = [] @@ -256,10 +262,15 @@ def apply_anthropic_caching_strategy( # Calculate mathematically optimized cache threshold if cache_threshold_tokens is None: + # Include system prompt tokens in calculation for accurate density (like compression does) + # Use token_counter on combined messages to match compression's calculation method + from litellm import token_counter + total_tokens = token_counter(model=model_name, messages=[working_system_prompt] + conversation_messages) if conversation_messages else 0 + cache_threshold_tokens = calculate_optimal_cache_threshold( context_window_tokens, len(conversation_messages), - get_messages_token_count(conversation_messages, model_name) if conversation_messages else 0 + total_tokens # Now includes system prompt for accurate density calculation ) logger.info(f"📊 Applying single cache breakpoint strategy for {len(conversation_messages)} messages") @@ -307,6 +318,7 @@ def apply_anthropic_caching_strategy( max_cacheable_tokens = int(context_window_tokens * 0.8) if total_conversation_tokens <= max_cacheable_tokens: + logger.debug(f"Conversation fits within cache limits - use chunked approach") # Conversation fits within cache limits - use chunked approach chunks_created = create_conversation_chunks( conversation_messages, @@ -350,6 +362,7 @@ def create_conversation_chunks( Final messages are NEVER cached to prevent cache invalidation. Returns number of cache blocks created. 
""" + logger.debug(f"Creating conversation chunks - chunk threshold: {chunk_threshold_tokens}, max blocks: {max_blocks}") if not messages or max_blocks <= 0: return 0 diff --git a/backend/core/agentpress/thread_manager.py b/backend/core/agentpress/thread_manager.py index 74dd42e4..d0ac4d04 100644 --- a/backend/core/agentpress/thread_manager.py +++ b/backend/core/agentpress/thread_manager.py @@ -17,6 +17,7 @@ from langfuse.client import StatefulGenerationClient, StatefulTraceClient from core.services.langfuse import langfuse from datetime import datetime, timezone from core.billing.billing_integration import billing_integration +from litellm.utils import token_counter ToolChoice = Literal["auto", "required", "none"] @@ -305,8 +306,11 @@ class ThreadManager: if ENABLE_CONTEXT_MANAGER: logger.debug(f"Context manager enabled, compressing {len(messages)} messages") context_manager = ContextManager() + compressed_messages = context_manager.compress_messages( - messages, llm_model, max_tokens=llm_max_tokens + messages, llm_model, max_tokens=llm_max_tokens, + actual_total_tokens=None, # Will be calculated inside + system_prompt=system_prompt # KEY FIX: No caching during compression ) logger.debug(f"Context compression completed: {len(messages)} -> {len(compressed_messages)} messages") messages = compressed_messages @@ -340,6 +344,10 @@ class ThreadManager: except Exception as e: logger.warning(f"Failed to update Langfuse generation: {e}") + # Log final prepared messages token count + final_prepared_tokens = token_counter(model=llm_model, messages=prepared_messages) + logger.info(f"📤 Final prepared messages being sent to LLM: {final_prepared_tokens} tokens") + # Make LLM call try: llm_response = await make_llm_api_call(