mirror of https://github.com/kortix-ai/suna.git
fix: agent stopping in middle of task
This commit is contained in:
parent
c5a155554d
commit
ca84e76c56
|
@ -13,7 +13,7 @@ This module provides comprehensive conversation management, including:
|
|||
import json
|
||||
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal, cast, Callable
|
||||
from core.services.llm import make_llm_api_call
|
||||
from core.utils.llm_cache_utils import apply_cache_to_messages
|
||||
from core.utils.llm_cache_utils import apply_cache_to_messages, validate_cache_blocks
|
||||
from core.agentpress.tool import Tool
|
||||
from core.agentpress.tool_registry import ToolRegistry
|
||||
from core.agentpress.context_manager import ContextManager
|
||||
|
@ -471,14 +471,11 @@ When using the tools:
|
|||
if isinstance(msg, dict) and msg.get('role') == 'user':
|
||||
last_user_index = i
|
||||
|
||||
# Add all messages (temporary messages are no longer used)
|
||||
prepared_messages.extend(messages)
|
||||
|
||||
# Add partial assistant content for auto-continue context (without saving to DB)
|
||||
if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
|
||||
partial_content = continuous_state.get('accumulated_content', '')
|
||||
|
||||
# Create temporary assistant message with just the text content
|
||||
temporary_assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": partial_content
|
||||
|
@ -491,7 +488,47 @@ When using the tools:
|
|||
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
|
||||
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
|
||||
|
||||
prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
|
||||
if token_count < 80_000:
|
||||
prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
|
||||
prepared_messages = validate_cache_blocks(prepared_messages, llm_model)
|
||||
else:
|
||||
logger.warning(f"⚠️ Skipping cache formatting due to high token count: {token_count}")
|
||||
|
||||
try:
|
||||
final_token_count = token_counter(model=llm_model, messages=prepared_messages)
|
||||
|
||||
if final_token_count != token_count:
|
||||
logger.info(f"📊 Final token count: {final_token_count} (initial was {token_count})")
|
||||
|
||||
from core.ai_models import model_manager
|
||||
context_window = model_manager.get_context_window(llm_model)
|
||||
|
||||
if context_window >= 200_000:
|
||||
safe_limit = 168_000
|
||||
elif context_window >= 100_000:
|
||||
safe_limit = context_window - 20_000
|
||||
else:
|
||||
safe_limit = context_window - 10_000
|
||||
|
||||
if final_token_count > safe_limit:
|
||||
logger.warning(f"⚠️ Token count {final_token_count} exceeds safe limit {safe_limit}, compressing messages...")
|
||||
prepared_messages = self.context_manager.compress_messages(
|
||||
prepared_messages,
|
||||
llm_model,
|
||||
max_tokens=safe_limit
|
||||
)
|
||||
compressed_token_count = token_counter(model=llm_model, messages=prepared_messages)
|
||||
logger.info(f"✅ Compressed messages: {final_token_count} → {compressed_token_count} tokens")
|
||||
|
||||
if compressed_token_count > safe_limit:
|
||||
logger.error(f"❌ Still over limit after compression: {compressed_token_count} > {safe_limit}")
|
||||
prepared_messages = self.context_manager.compress_messages_by_omitting_messages(
|
||||
prepared_messages,
|
||||
llm_model,
|
||||
max_tokens=safe_limit - 10_000
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in token checking/compression: {str(e)}")
|
||||
|
||||
system_count = sum(1 for msg in prepared_messages if msg.get('role') == 'system')
|
||||
if system_count > 1:
|
||||
|
@ -508,7 +545,6 @@ When using the tools:
|
|||
prepared_messages = filtered_messages
|
||||
logger.info(f"🔧 Reduced to 1 system message")
|
||||
|
||||
# Debug: Log what we're sending
|
||||
logger.info(f"📤 Sending {len(prepared_messages)} messages to LLM")
|
||||
for i, msg in enumerate(prepared_messages):
|
||||
role = msg.get('role', 'unknown')
|
||||
|
@ -538,7 +574,7 @@ When using the tools:
|
|||
)
|
||||
|
||||
llm_response = await make_llm_api_call(
|
||||
prepared_messages, # Pass the potentially modified messages
|
||||
prepared_messages,
|
||||
llm_model,
|
||||
temperature=llm_temperature,
|
||||
max_tokens=llm_max_tokens,
|
||||
|
|
|
@ -20,7 +20,7 @@ def get_resolved_model_id(model_name: str) -> str:
|
|||
return model_name
|
||||
|
||||
|
||||
def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 3584) -> Dict[str, Any]:
|
||||
def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 10000) -> Dict[str, Any]:
|
||||
if not message or not isinstance(message, dict):
|
||||
logger.debug(f"Skipping cache format: message is not a dict")
|
||||
return message
|
||||
|
@ -37,6 +37,7 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
|
|||
logger.debug(f"Content is already a list but no cache_control found")
|
||||
return message
|
||||
|
||||
# Increased min chars threshold to be more selective about what gets cached
|
||||
if len(str(content)) < min_chars_for_cache:
|
||||
logger.debug(f"Content too short for caching: {len(str(content))} < {min_chars_for_cache}")
|
||||
return message
|
||||
|
@ -68,11 +69,17 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
|
|||
|
||||
|
||||
def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
|
||||
max_messages_to_cache: int = 4) -> List[Dict[str, Any]]:
|
||||
max_messages_to_cache: int = 2) -> List[Dict[str, Any]]:
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
resolved_model = get_resolved_model_id(model_name)
|
||||
model_lower = resolved_model.lower()
|
||||
|
||||
if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
|
||||
logger.debug(f"Model {resolved_model} doesn't need cache_control blocks")
|
||||
return messages
|
||||
|
||||
logger.info(f"📊 apply_cache_to_messages called with {len(messages)} messages for model: {model_name} (resolved: {resolved_model})")
|
||||
|
||||
formatted_messages = []
|
||||
|
@ -88,23 +95,29 @@ def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
|
|||
formatted_messages.append(message)
|
||||
continue
|
||||
|
||||
if cache_count < max_messages_to_cache:
|
||||
logger.debug(f"Processing message {i+1}/{len(messages)} for caching")
|
||||
total_cached = already_cached_count + cache_count
|
||||
if total_cached < max_messages_to_cache:
|
||||
logger.debug(f"Processing message {i+1}/{len(messages)} for caching (total cached: {total_cached})")
|
||||
formatted_message = format_message_with_cache(message, resolved_model)
|
||||
|
||||
if formatted_message != message:
|
||||
cache_count += 1
|
||||
logger.info(f"✅ Cache applied to message {i+1}")
|
||||
logger.info(f"✅ Cache applied to message {i+1} (total cached: {already_cached_count + cache_count})")
|
||||
|
||||
formatted_messages.append(formatted_message)
|
||||
else:
|
||||
logger.debug(f"Skipping cache for message {i+1} - limit reached (total cached: {total_cached})")
|
||||
formatted_messages.append(message)
|
||||
|
||||
if cache_count > 0 or already_cached_count > 0:
|
||||
logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached for model {resolved_model}")
|
||||
total_final = cache_count + already_cached_count
|
||||
if total_final > 0:
|
||||
logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached, {total_final} total for model {resolved_model}")
|
||||
else:
|
||||
logger.debug(f"ℹ️ No messages needed caching for model {resolved_model}")
|
||||
|
||||
if total_final > max_messages_to_cache:
|
||||
logger.warning(f"⚠️ Total cached messages ({total_final}) exceeds limit ({max_messages_to_cache})")
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
||||
|
@ -135,3 +148,46 @@ def needs_cache_probe(model_name: str) -> bool:
|
|||
|
||||
return any(provider in model_lower for provider in streaming_cache_issues)
|
||||
|
||||
|
||||
def validate_cache_blocks(messages: List[Dict[str, Any]], model_name: str, max_blocks: int = 4) -> List[Dict[str, Any]]:
|
||||
resolved_model = get_resolved_model_id(model_name)
|
||||
model_lower = resolved_model.lower()
|
||||
|
||||
if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
|
||||
return messages
|
||||
|
||||
cache_block_count = 0
|
||||
for msg in messages:
|
||||
content = msg.get('content')
|
||||
if isinstance(content, list) and content:
|
||||
if isinstance(content[0], dict) and 'cache_control' in content[0]:
|
||||
cache_block_count += 1
|
||||
|
||||
if cache_block_count <= max_blocks:
|
||||
logger.debug(f"✅ Cache validation passed: {cache_block_count}/{max_blocks} blocks")
|
||||
return messages
|
||||
|
||||
logger.warning(f"⚠️ Cache validation failed: {cache_block_count}/{max_blocks} blocks. Removing excess cache blocks.")
|
||||
|
||||
fixed_messages = []
|
||||
blocks_seen = 0
|
||||
|
||||
for msg in messages:
|
||||
content = msg.get('content')
|
||||
if isinstance(content, list) and content:
|
||||
if isinstance(content[0], dict) and 'cache_control' in content[0]:
|
||||
blocks_seen += 1
|
||||
if blocks_seen > max_blocks:
|
||||
logger.info(f"🔧 Removing cache_control from message {blocks_seen} (role: {msg.get('role')})")
|
||||
new_content = [{k: v for k, v in content[0].items() if k != 'cache_control'}]
|
||||
fixed_messages.append({**msg, 'content': new_content})
|
||||
else:
|
||||
fixed_messages.append(msg)
|
||||
else:
|
||||
fixed_messages.append(msg)
|
||||
else:
|
||||
fixed_messages.append(msg)
|
||||
|
||||
logger.info(f"✅ Fixed cache blocks: {max_blocks}/{cache_block_count} blocks retained")
|
||||
return fixed_messages
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
|
Loading…
Reference in New Issue