fix: agent stopping in middle of task

Saumya 2025-09-16 01:23:31 +05:30
parent c5a155554d
commit ca84e76c56
3 changed files with 106 additions and 15 deletions

View File

@@ -13,7 +13,7 @@ This module provides comprehensive conversation management, including:
 import json
 from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal, cast, Callable
 from core.services.llm import make_llm_api_call
-from core.utils.llm_cache_utils import apply_cache_to_messages
+from core.utils.llm_cache_utils import apply_cache_to_messages, validate_cache_blocks
 from core.agentpress.tool import Tool
 from core.agentpress.tool_registry import ToolRegistry
 from core.agentpress.context_manager import ContextManager
@@ -471,14 +471,11 @@ When using the tools:
     if isinstance(msg, dict) and msg.get('role') == 'user':
         last_user_index = i
-# Add all messages (temporary messages are no longer used)
 prepared_messages.extend(messages)
-# Add partial assistant content for auto-continue context (without saving to DB)
 if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
     partial_content = continuous_state.get('accumulated_content', '')
-    # Create temporary assistant message with just the text content
     temporary_assistant_message = {
         "role": "assistant",
         "content": partial_content
@@ -491,7 +488,47 @@ When using the tools:
 openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
 logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
-prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
+if token_count < 80_000:
+    prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
+    prepared_messages = validate_cache_blocks(prepared_messages, llm_model)
+else:
+    logger.warning(f"⚠️ Skipping cache formatting due to high token count: {token_count}")
+
+try:
+    final_token_count = token_counter(model=llm_model, messages=prepared_messages)
+    if final_token_count != token_count:
+        logger.info(f"📊 Final token count: {final_token_count} (initial was {token_count})")
+
+    from core.ai_models import model_manager
+    context_window = model_manager.get_context_window(llm_model)
+
+    if context_window >= 200_000:
+        safe_limit = 168_000
+    elif context_window >= 100_000:
+        safe_limit = context_window - 20_000
+    else:
+        safe_limit = context_window - 10_000
+
+    if final_token_count > safe_limit:
+        logger.warning(f"⚠️ Token count {final_token_count} exceeds safe limit {safe_limit}, compressing messages...")
+        prepared_messages = self.context_manager.compress_messages(
+            prepared_messages,
+            llm_model,
+            max_tokens=safe_limit
+        )
+        compressed_token_count = token_counter(model=llm_model, messages=prepared_messages)
+        logger.info(f"✅ Compressed messages: {final_token_count} → {compressed_token_count} tokens")
+
+        if compressed_token_count > safe_limit:
+            logger.error(f"❌ Still over limit after compression: {compressed_token_count} > {safe_limit}")
+            prepared_messages = self.context_manager.compress_messages_by_omitting_messages(
+                prepared_messages,
+                llm_model,
+                max_tokens=safe_limit - 10_000
+            )
+except Exception as e:
+    logger.error(f"Error in token checking/compression: {str(e)}")
 system_count = sum(1 for msg in prepared_messages if msg.get('role') == 'system')
 if system_count > 1:
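For reference, a minimal standalone sketch of the safe-limit rule introduced in the hunk above. The function name pick_safe_limit and the sample context-window sizes are illustrative only and are not part of the commit.

def pick_safe_limit(context_window: int) -> int:
    """Mirror of the headroom rule above: keep a buffer below the model's context window."""
    if context_window >= 200_000:
        return 168_000  # fixed cap for large-context models
    if context_window >= 100_000:
        return context_window - 20_000
    return context_window - 10_000

# Illustrative checks (window sizes are hypothetical examples):
assert pick_safe_limit(200_000) == 168_000
assert pick_safe_limit(128_000) == 108_000
assert pick_safe_limit(32_000) == 22_000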
@@ -508,7 +545,6 @@ When using the tools:
 prepared_messages = filtered_messages
 logger.info(f"🔧 Reduced to 1 system message")
-# Debug: Log what we're sending
 logger.info(f"📤 Sending {len(prepared_messages)} messages to LLM")
 for i, msg in enumerate(prepared_messages):
     role = msg.get('role', 'unknown')
@@ -538,7 +574,7 @@ When using the tools:
 )
 llm_response = await make_llm_api_call(
-    prepared_messages,  # Pass the potentially modified messages
+    prepared_messages,
     llm_model,
     temperature=llm_temperature,
     max_tokens=llm_max_tokens,

View File

@@ -20,7 +20,7 @@ def get_resolved_model_id(model_name: str) -> str:
     return model_name
-def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 3584) -> Dict[str, Any]:
+def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 10000) -> Dict[str, Any]:
     if not message or not isinstance(message, dict):
        logger.debug(f"Skipping cache format: message is not a dict")
        return message
@@ -37,6 +37,7 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
         logger.debug(f"Content is already a list but no cache_control found")
         return message
+    # Increased min chars threshold to be more selective about what gets cached
     if len(str(content)) < min_chars_for_cache:
         logger.debug(f"Content too short for caching: {len(str(content))} < {min_chars_for_cache}")
         return message
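As context for the raised min_chars_for_cache threshold, a rough sketch of the before/after message shape the helper is expected to produce for content long enough to cache. The exact block layout is an assumption (Anthropic-style cache_control blocks, which is also the shape validate_cache_blocks checks below), not something shown in this diff.

# Hypothetical example: a message whose content exceeds min_chars_for_cache (now 10000 chars).
before = {"role": "user", "content": "x" * 12_000}

# Assumed output shape: content wrapped in a single text block carrying a cache_control marker.
after = {
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "x" * 12_000,
            "cache_control": {"type": "ephemeral"},
        }
    ],
}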
@@ -68,11 +69,17 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
 def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
-                            max_messages_to_cache: int = 4) -> List[Dict[str, Any]]:
+                            max_messages_to_cache: int = 2) -> List[Dict[str, Any]]:
     if not messages:
         return messages
     resolved_model = get_resolved_model_id(model_name)
+    model_lower = resolved_model.lower()
+    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
+        logger.debug(f"Model {resolved_model} doesn't need cache_control blocks")
+        return messages
     logger.info(f"📊 apply_cache_to_messages called with {len(messages)} messages for model: {model_name} (resolved: {resolved_model})")
     formatted_messages = []
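A usage sketch of the new provider gate. The model identifiers are placeholders, and the snippet assumes this repository's core.utils.llm_cache_utils module (the same import path used in the first file of this commit) is on the import path.

from core.utils.llm_cache_utils import apply_cache_to_messages

messages = [{"role": "user", "content": "hello " * 2_000}]

# Non-Anthropic model: the gate returns the messages unchanged, no cache_control added.
untouched = apply_cache_to_messages(messages, "openai/gpt-4o")

# Anthropic-family model: at most max_messages_to_cache (now 2) messages get cache_control blocks.
cached = apply_cache_to_messages(messages, "anthropic/claude-sonnet-4", max_messages_to_cache=2)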
@@ -88,23 +95,29 @@ def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
             formatted_messages.append(message)
             continue
-        if cache_count < max_messages_to_cache:
-            logger.debug(f"Processing message {i+1}/{len(messages)} for caching")
+        total_cached = already_cached_count + cache_count
+        if total_cached < max_messages_to_cache:
+            logger.debug(f"Processing message {i+1}/{len(messages)} for caching (total cached: {total_cached})")
             formatted_message = format_message_with_cache(message, resolved_model)
             if formatted_message != message:
                 cache_count += 1
-                logger.info(f"✅ Cache applied to message {i+1}")
+                logger.info(f"✅ Cache applied to message {i+1} (total cached: {already_cached_count + cache_count})")
             formatted_messages.append(formatted_message)
         else:
+            logger.debug(f"Skipping cache for message {i+1} - limit reached (total cached: {total_cached})")
             formatted_messages.append(message)
-    if cache_count > 0 or already_cached_count > 0:
-        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached for model {resolved_model}")
+    total_final = cache_count + already_cached_count
+    if total_final > 0:
+        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached, {total_final} total for model {resolved_model}")
     else:
         logger.debug(f" No messages needed caching for model {resolved_model}")
+    if total_final > max_messages_to_cache:
+        logger.warning(f"⚠️ Total cached messages ({total_final}) exceeds limit ({max_messages_to_cache})")
     return formatted_messages
@@ -135,3 +148,46 @@ def needs_cache_probe(model_name: str) -> bool:
     return any(provider in model_lower for provider in streaming_cache_issues)
+
+def validate_cache_blocks(messages: List[Dict[str, Any]], model_name: str, max_blocks: int = 4) -> List[Dict[str, Any]]:
+    resolved_model = get_resolved_model_id(model_name)
+    model_lower = resolved_model.lower()
+    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
+        return messages
+
+    cache_block_count = 0
+    for msg in messages:
+        content = msg.get('content')
+        if isinstance(content, list) and content:
+            if isinstance(content[0], dict) and 'cache_control' in content[0]:
+                cache_block_count += 1
+
+    if cache_block_count <= max_blocks:
+        logger.debug(f"✅ Cache validation passed: {cache_block_count}/{max_blocks} blocks")
+        return messages
+
+    logger.warning(f"⚠️ Cache validation failed: {cache_block_count}/{max_blocks} blocks. Removing excess cache blocks.")
+    fixed_messages = []
+    blocks_seen = 0
+    for msg in messages:
+        content = msg.get('content')
+        if isinstance(content, list) and content:
+            if isinstance(content[0], dict) and 'cache_control' in content[0]:
+                blocks_seen += 1
+                if blocks_seen > max_blocks:
+                    logger.info(f"🔧 Removing cache_control from message {blocks_seen} (role: {msg.get('role')})")
+                    new_content = [{k: v for k, v in content[0].items() if k != 'cache_control'}]
+                    fixed_messages.append({**msg, 'content': new_content})
+                else:
+                    fixed_messages.append(msg)
+            else:
+                fixed_messages.append(msg)
+        else:
+            fixed_messages.append(msg)
+
+    logger.info(f"✅ Fixed cache blocks: {max_blocks}/{cache_block_count} blocks retained")
+    return fixed_messages
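A usage sketch for the new validate_cache_blocks helper, under the assumption that the 4-block cap tracks Anthropic's limit on cache_control breakpoints. The sample messages and model name are made up; the import assumes this repository is on the import path.

from core.utils.llm_cache_utils import validate_cache_blocks

# Six messages, each carrying a cache_control block: two more than the cap allows.
msgs = [
    {"role": "user", "content": [{"type": "text", "text": f"chunk {i}", "cache_control": {"type": "ephemeral"}}]}
    for i in range(6)
]

fixed = validate_cache_blocks(msgs, "anthropic/claude-sonnet-4")
kept = sum(
    1 for m in fixed
    if isinstance(m["content"], list) and "cache_control" in m["content"][0]
)
assert kept == 4  # only the first four cache blocks survive; the rest are stripped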

View File

@@ -1 +0,0 @@