fix: agent stopping in the middle of a task

Saumya 2025-09-16 01:23:31 +05:30
parent c5a155554d
commit ca84e76c56
3 changed files with 106 additions and 15 deletions

View File

@@ -13,7 +13,7 @@ This module provides comprehensive conversation management, including:
import json
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal, cast, Callable
from core.services.llm import make_llm_api_call
from core.utils.llm_cache_utils import apply_cache_to_messages
from core.utils.llm_cache_utils import apply_cache_to_messages, validate_cache_blocks
from core.agentpress.tool import Tool
from core.agentpress.tool_registry import ToolRegistry
from core.agentpress.context_manager import ContextManager
@@ -471,14 +471,11 @@ When using the tools:
    if isinstance(msg, dict) and msg.get('role') == 'user':
        last_user_index = i

# Add all messages (temporary messages are no longer used)
prepared_messages.extend(messages)

# Add partial assistant content for auto-continue context (without saving to DB)
if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
    partial_content = continuous_state.get('accumulated_content', '')

    # Create temporary assistant message with just the text content
    temporary_assistant_message = {
        "role": "assistant",
        "content": partial_content
@@ -491,7 +488,47 @@ When using the tools:
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
if token_count < 80_000:
    prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
    prepared_messages = validate_cache_blocks(prepared_messages, llm_model)
else:
    logger.warning(f"⚠️ Skipping cache formatting due to high token count: {token_count}")

try:
    final_token_count = token_counter(model=llm_model, messages=prepared_messages)
    if final_token_count != token_count:
        logger.info(f"📊 Final token count: {final_token_count} (initial was {token_count})")

    from core.ai_models import model_manager
    context_window = model_manager.get_context_window(llm_model)

    if context_window >= 200_000:
        safe_limit = 168_000
    elif context_window >= 100_000:
        safe_limit = context_window - 20_000
    else:
        safe_limit = context_window - 10_000

    if final_token_count > safe_limit:
        logger.warning(f"⚠️ Token count {final_token_count} exceeds safe limit {safe_limit}, compressing messages...")
        prepared_messages = self.context_manager.compress_messages(
            prepared_messages,
            llm_model,
            max_tokens=safe_limit
        )
        compressed_token_count = token_counter(model=llm_model, messages=prepared_messages)
        logger.info(f"✅ Compressed messages: {final_token_count} → {compressed_token_count} tokens")

        if compressed_token_count > safe_limit:
            logger.error(f"❌ Still over limit after compression: {compressed_token_count} > {safe_limit}")
            prepared_messages = self.context_manager.compress_messages_by_omitting_messages(
                prepared_messages,
                llm_model,
                max_tokens=safe_limit - 10_000
            )
except Exception as e:
    logger.error(f"Error in token checking/compression: {str(e)}")
system_count = sum(1 for msg in prepared_messages if msg.get('role') == 'system')
if system_count > 1:
@@ -508,7 +545,6 @@ When using the tools:
    prepared_messages = filtered_messages
    logger.info(f"🔧 Reduced to 1 system message")

# Debug: Log what we're sending
logger.info(f"📤 Sending {len(prepared_messages)} messages to LLM")
for i, msg in enumerate(prepared_messages):
    role = msg.get('role', 'unknown')
@@ -538,7 +574,7 @@ When using the tools:
)

llm_response = await make_llm_api_call(
    prepared_messages, # Pass the potentially modified messages
    prepared_messages,
    llm_model,
    temperature=llm_temperature,
    max_tokens=llm_max_tokens,
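
The hunk above caps prompt caching at 80,000 tokens and derives a per-model safe limit before falling back to compression. A minimal standalone sketch of that limit ladder, with the thresholds taken directly from the diff (the function name is illustrative, not part of the codebase):

def pick_safe_token_limit(context_window: int) -> int:
    """Sketch of the safe-limit selection introduced above: keep generous
    headroom under the model's context window before compressing."""
    if context_window >= 200_000:
        return 168_000                      # fixed cap for very large context windows
    if context_window >= 100_000:
        return context_window - 20_000      # mid-size windows keep 20k headroom
    return context_window - 10_000          # smaller windows keep 10k headroom

# e.g. a 200k-context model starts compressing once the prompt exceeds 168k tokens
assert pick_safe_token_limit(200_000) == 168_000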

View File

@@ -20,7 +20,7 @@ def get_resolved_model_id(model_name: str) -> str:
    return model_name

def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 3584) -> Dict[str, Any]:
def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 10000) -> Dict[str, Any]:
    if not message or not isinstance(message, dict):
        logger.debug(f"Skipping cache format: message is not a dict")
        return message
@@ -37,6 +37,7 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
logger.debug(f"Content is already a list but no cache_control found")
return message
# Increased min chars threshold to be more selective about what gets cached
if len(str(content)) < min_chars_for_cache:
logger.debug(f"Content too short for caching: {len(str(content))} < {min_chars_for_cache}")
return message
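
For context, Anthropic-style prompt caching marks a content block with a cache_control entry, which is why this helper rewrites plain string content into a block list once it clears the character threshold. A hedged sketch of that transformation (the block fields other than cache_control are an assumption about the shape, not a copy of the project's code):

from typing import Any, Dict

def to_cached_block(message: Dict[str, Any], min_chars_for_cache: int = 10000) -> Dict[str, Any]:
    """Sketch: wrap long string content in a single text block flagged for caching;
    short or already-structured content is returned unchanged."""
    content = message.get("content")
    if not isinstance(content, str) or len(content) < min_chars_for_cache:
        return message
    return {
        **message,
        "content": [{
            "type": "text",
            "text": content,
            "cache_control": {"type": "ephemeral"},  # Anthropic prompt-caching marker
        }],
    }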
@@ -68,11 +69,17 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
                            max_messages_to_cache: int = 4) -> List[Dict[str, Any]]:
                            max_messages_to_cache: int = 2) -> List[Dict[str, Any]]:
    if not messages:
        return messages

    resolved_model = get_resolved_model_id(model_name)
    model_lower = resolved_model.lower()
    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
        logger.debug(f"Model {resolved_model} doesn't need cache_control blocks")
        return messages

    logger.info(f"📊 apply_cache_to_messages called with {len(messages)} messages for model: {model_name} (resolved: {resolved_model})")

    formatted_messages = []
@@ -88,23 +95,29 @@ def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
            formatted_messages.append(message)
            continue

        if cache_count < max_messages_to_cache:
            logger.debug(f"Processing message {i+1}/{len(messages)} for caching")
        total_cached = already_cached_count + cache_count
        if total_cached < max_messages_to_cache:
            logger.debug(f"Processing message {i+1}/{len(messages)} for caching (total cached: {total_cached})")
            formatted_message = format_message_with_cache(message, resolved_model)
            if formatted_message != message:
                cache_count += 1
                logger.info(f"✅ Cache applied to message {i+1}")
                logger.info(f"✅ Cache applied to message {i+1} (total cached: {already_cached_count + cache_count})")
            formatted_messages.append(formatted_message)
        else:
            logger.debug(f"Skipping cache for message {i+1} - limit reached (total cached: {total_cached})")
            formatted_messages.append(message)

    if cache_count > 0 or already_cached_count > 0:
        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached for model {resolved_model}")
    total_final = cache_count + already_cached_count
    if total_final > 0:
        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached, {total_final} total for model {resolved_model}")
    else:
        logger.debug(f" No messages needed caching for model {resolved_model}")

    if total_final > max_messages_to_cache:
        logger.warning(f"⚠️ Total cached messages ({total_final}) exceeds limit ({max_messages_to_cache})")

    return formatted_messages
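
A hedged usage sketch of the tightened accounting: with the new default of max_messages_to_cache=2, messages that already carry cache_control now count against the same budget, so at most two messages can come out of this pass cached in total (the model string and message contents below are illustrative):

history = [
    {"role": "system", "content": "x" * 12000},    # long enough to clear the 10,000-char threshold
    {"role": "user", "content": "short question"}, # too short to be cached
    {"role": "user", "content": "y" * 12000},
]
cached = apply_cache_to_messages(history, "claude-sonnet-4")
# No more than two of these messages come back with a cache_control block attached.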
@@ -135,3 +148,46 @@ def needs_cache_probe(model_name: str) -> bool:
    return any(provider in model_lower for provider in streaming_cache_issues)

def validate_cache_blocks(messages: List[Dict[str, Any]], model_name: str, max_blocks: int = 4) -> List[Dict[str, Any]]:
    resolved_model = get_resolved_model_id(model_name)
    model_lower = resolved_model.lower()
    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
        return messages

    cache_block_count = 0
    for msg in messages:
        content = msg.get('content')
        if isinstance(content, list) and content:
            if isinstance(content[0], dict) and 'cache_control' in content[0]:
                cache_block_count += 1

    if cache_block_count <= max_blocks:
        logger.debug(f"✅ Cache validation passed: {cache_block_count}/{max_blocks} blocks")
        return messages

    logger.warning(f"⚠️ Cache validation failed: {cache_block_count}/{max_blocks} blocks. Removing excess cache blocks.")

    fixed_messages = []
    blocks_seen = 0
    for msg in messages:
        content = msg.get('content')
        if isinstance(content, list) and content:
            if isinstance(content[0], dict) and 'cache_control' in content[0]:
                blocks_seen += 1
                if blocks_seen > max_blocks:
                    logger.info(f"🔧 Removing cache_control from message {blocks_seen} (role: {msg.get('role')})")
                    new_content = [{k: v for k, v in content[0].items() if k != 'cache_control'}]
                    fixed_messages.append({**msg, 'content': new_content})
                else:
                    fixed_messages.append(msg)
            else:
                fixed_messages.append(msg)
        else:
            fixed_messages.append(msg)

    logger.info(f"✅ Fixed cache blocks: {max_blocks}/{cache_block_count} blocks retained")
    return fixed_messages
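
Finally, a hedged example of what the new validator guards against: Anthropic-family requests allow only a small number of cache_control breakpoints (hence max_blocks defaulting to 4), so blocks past the fourth are stripped while message order and text are preserved (the model string and contents are illustrative):

overcached = [
    {"role": "user",
     "content": [{"type": "text", "text": f"chunk {i}",
                  "cache_control": {"type": "ephemeral"}}]}
    for i in range(6)
]
trimmed = validate_cache_blocks(overcached, "claude-3-5-sonnet")
kept = sum(1 for m in trimmed if "cache_control" in m["content"][0])
assert kept == 4  # the last two messages keep their text but lose the cache marker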

View File

@@ -1 +0,0 @@