diff --git a/backend/core/agentpress/thread_manager.py b/backend/core/agentpress/thread_manager.py
index 030725a7..b41ce9d3 100644
--- a/backend/core/agentpress/thread_manager.py
+++ b/backend/core/agentpress/thread_manager.py
@@ -13,7 +13,7 @@ This module provides comprehensive conversation management, including:
 import json
 from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal, cast, Callable
 from core.services.llm import make_llm_api_call
-from core.utils.llm_cache_utils import apply_cache_to_messages
+from core.utils.llm_cache_utils import apply_cache_to_messages, validate_cache_blocks
 from core.agentpress.tool import Tool
 from core.agentpress.tool_registry import ToolRegistry
 from core.agentpress.context_manager import ContextManager
@@ -471,14 +471,11 @@ When using the tools:
             if isinstance(msg, dict) and msg.get('role') == 'user':
                 last_user_index = i
 
-        # Add all messages (temporary messages are no longer used)
         prepared_messages.extend(messages)
 
-        # Add partial assistant content for auto-continue context (without saving to DB)
         if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
             partial_content = continuous_state.get('accumulated_content', '')
 
-            # Create temporary assistant message with just the text content
             temporary_assistant_message = {
                 "role": "assistant",
                 "content": partial_content
@@ -491,7 +488,47 @@ When using the tools:
         openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
         logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
 
-        prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
+        if token_count < 80_000:
+            prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
+            prepared_messages = validate_cache_blocks(prepared_messages, llm_model)
+        else:
+            logger.warning(f"âš ī¸ Skipping cache formatting due to high token count: {token_count}")
+
+        try:
+            final_token_count = token_counter(model=llm_model, messages=prepared_messages)
+
+            if final_token_count != token_count:
+                logger.info(f"📊 Final token count: {final_token_count} (initial was {token_count})")
+
+            from core.ai_models import model_manager
+            context_window = model_manager.get_context_window(llm_model)
+
+            if context_window >= 200_000:
+                safe_limit = 168_000
+            elif context_window >= 100_000:
+                safe_limit = context_window - 20_000
+            else:
+                safe_limit = context_window - 10_000
+
+            if final_token_count > safe_limit:
+                logger.warning(f"âš ī¸ Token count {final_token_count} exceeds safe limit {safe_limit}, compressing messages...")
+                prepared_messages = self.context_manager.compress_messages(
+                    prepared_messages,
+                    llm_model,
+                    max_tokens=safe_limit
+                )
+                compressed_token_count = token_counter(model=llm_model, messages=prepared_messages)
+                logger.info(f"✅ Compressed messages: {final_token_count} → {compressed_token_count} tokens")
+
+                if compressed_token_count > safe_limit:
+                    logger.error(f"❌ Still over limit after compression: {compressed_token_count} > {safe_limit}")
+                    prepared_messages = self.context_manager.compress_messages_by_omitting_messages(
+                        prepared_messages,
+                        llm_model,
+                        max_tokens=safe_limit - 10_000
+                    )
+        except Exception as e:
+            logger.error(f"Error in token checking/compression: {str(e)}")
 
         system_count = sum(1 for msg in prepared_messages if msg.get('role') == 'system')
         if system_count > 1:
@@ -508,7 +545,6 @@ When using the tools:
             prepared_messages = filtered_messages
             logger.info(f"🔧 Reduced to 1 system message")
 
-        # Debug: Log what we're sending
         logger.info(f"📤 Sending {len(prepared_messages)} messages to LLM")
         for i, msg in enumerate(prepared_messages):
             role = msg.get('role', 'unknown')
@@ -538,7 +574,7 @@ When using the tools:
             )
 
         llm_response = await make_llm_api_call(
-            prepared_messages,  # Pass the potentially modified messages
+            prepared_messages,
             llm_model,
             temperature=llm_temperature,
             max_tokens=llm_max_tokens,
diff --git a/backend/core/utils/llm_cache_utils.py b/backend/core/utils/llm_cache_utils.py
index cf2ff27d..75379244 100644
--- a/backend/core/utils/llm_cache_utils.py
+++ b/backend/core/utils/llm_cache_utils.py
@@ -20,7 +20,7 @@ def get_resolved_model_id(model_name: str) -> str:
     return model_name
 
 
-def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 3584) -> Dict[str, Any]:
+def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 10000) -> Dict[str, Any]:
     if not message or not isinstance(message, dict):
         logger.debug(f"Skipping cache format: message is not a dict")
         return message
@@ -37,6 +37,7 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
         logger.debug(f"Content is already a list but no cache_control found")
         return message
 
+    # Increased min chars threshold to be more selective about what gets cached
     if len(str(content)) < min_chars_for_cache:
         logger.debug(f"Content too short for caching: {len(str(content))} < {min_chars_for_cache}")
         return message
@@ -68,11 +69,17 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
 
 
 def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
-                            max_messages_to_cache: int = 4) -> List[Dict[str, Any]]:
+                            max_messages_to_cache: int = 2) -> List[Dict[str, Any]]:
     if not messages:
         return messages
 
     resolved_model = get_resolved_model_id(model_name)
+    model_lower = resolved_model.lower()
+
+    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
+        logger.debug(f"Model {resolved_model} doesn't need cache_control blocks")
+        return messages
+
     logger.info(f"📊 apply_cache_to_messages called with {len(messages)} messages for model: {model_name} (resolved: {resolved_model})")
 
     formatted_messages = []
@@ -88,23 +95,29 @@ def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
             formatted_messages.append(message)
             continue
 
-        if cache_count < max_messages_to_cache:
-            logger.debug(f"Processing message {i+1}/{len(messages)} for caching")
+        total_cached = already_cached_count + cache_count
+        if total_cached < max_messages_to_cache:
+            logger.debug(f"Processing message {i+1}/{len(messages)} for caching (total cached: {total_cached})")
             formatted_message = format_message_with_cache(message, resolved_model)
             if formatted_message != message:
                 cache_count += 1
-                logger.info(f"✅ Cache applied to message {i+1}")
+                logger.info(f"✅ Cache applied to message {i+1} (total cached: {already_cached_count + cache_count})")
             formatted_messages.append(formatted_message)
         else:
+            logger.debug(f"Skipping cache for message {i+1} - limit reached (total cached: {total_cached})")
             formatted_messages.append(message)
 
-    if cache_count > 0 or already_cached_count > 0:
-        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached for model {resolved_model}")
+    total_final = cache_count + already_cached_count
+    if total_final > 0:
+        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached, {total_final} total for model {resolved_model}")
     else:
        logger.debug(f"â„šī¸ No messages needed caching for model {resolved_model}")
 
+    if total_final > max_messages_to_cache:
+        logger.warning(f"âš ī¸ Total cached messages ({total_final}) exceeds limit ({max_messages_to_cache})")
+
     return formatted_messages
@@ -135,3 +148,46 @@ def needs_cache_probe(model_name: str) -> bool:
 
     return any(provider in model_lower for provider in streaming_cache_issues)
 
+
+def validate_cache_blocks(messages: List[Dict[str, Any]], model_name: str, max_blocks: int = 4) -> List[Dict[str, Any]]:
+    resolved_model = get_resolved_model_id(model_name)
+    model_lower = resolved_model.lower()
+
+    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
+        return messages
+
+    cache_block_count = 0
+    for msg in messages:
+        content = msg.get('content')
+        if isinstance(content, list) and content:
+            if isinstance(content[0], dict) and 'cache_control' in content[0]:
+                cache_block_count += 1
+
+    if cache_block_count <= max_blocks:
+        logger.debug(f"✅ Cache validation passed: {cache_block_count}/{max_blocks} blocks")
+        return messages
+
+    logger.warning(f"âš ī¸ Cache validation failed: {cache_block_count}/{max_blocks} blocks. Removing excess cache blocks.")
+
+    fixed_messages = []
+    blocks_seen = 0
+
+    for msg in messages:
+        content = msg.get('content')
+        if isinstance(content, list) and content:
+            if isinstance(content[0], dict) and 'cache_control' in content[0]:
+                blocks_seen += 1
+                if blocks_seen > max_blocks:
+                    logger.info(f"🔧 Removing cache_control from message {blocks_seen} (role: {msg.get('role')})")
+                    new_content = [{k: v for k, v in content[0].items() if k != 'cache_control'}]
+                    fixed_messages.append({**msg, 'content': new_content})
+                else:
+                    fixed_messages.append(msg)
+            else:
+                fixed_messages.append(msg)
+        else:
+            fixed_messages.append(msg)
+
+    logger.info(f"✅ Fixed cache blocks: {max_blocks}/{cache_block_count} blocks retained")
+    return fixed_messages
+
diff --git a/backend/test_multi_model_caching.py b/backend/test_multi_model_caching.py
deleted file mode 100644
index 0519ecba..00000000
--- a/backend/test_multi_model_caching.py
+++ /dev/null
@@ -1 +0,0 @@
- 
\ No newline at end of file