mirror of https://github.com/kortix-ai/suna.git
fix: agent stopping in middle of task
parent c5a155554d
commit ca84e76c56
@@ -13,7 +13,7 @@ This module provides comprehensive conversation management, including:
 import json
 from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal, cast, Callable
 from core.services.llm import make_llm_api_call
-from core.utils.llm_cache_utils import apply_cache_to_messages
+from core.utils.llm_cache_utils import apply_cache_to_messages, validate_cache_blocks
 from core.agentpress.tool import Tool
 from core.agentpress.tool_registry import ToolRegistry
 from core.agentpress.context_manager import ContextManager
@@ -471,14 +471,11 @@ When using the tools:
             if isinstance(msg, dict) and msg.get('role') == 'user':
                 last_user_index = i
 
-        # Add all messages (temporary messages are no longer used)
         prepared_messages.extend(messages)
 
-        # Add partial assistant content for auto-continue context (without saving to DB)
         if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
             partial_content = continuous_state.get('accumulated_content', '')
 
-            # Create temporary assistant message with just the text content
             temporary_assistant_message = {
                 "role": "assistant",
                 "content": partial_content
@@ -491,7 +488,47 @@ When using the tools:
         openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
         logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
 
-        prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
+        if token_count < 80_000:
+            prepared_messages = apply_cache_to_messages(prepared_messages, llm_model)
+            prepared_messages = validate_cache_blocks(prepared_messages, llm_model)
+        else:
+            logger.warning(f"⚠️ Skipping cache formatting due to high token count: {token_count}")
+
+        try:
+            final_token_count = token_counter(model=llm_model, messages=prepared_messages)
+
+            if final_token_count != token_count:
+                logger.info(f"📊 Final token count: {final_token_count} (initial was {token_count})")
+
+            from core.ai_models import model_manager
+            context_window = model_manager.get_context_window(llm_model)
+
+            if context_window >= 200_000:
+                safe_limit = 168_000
+            elif context_window >= 100_000:
+                safe_limit = context_window - 20_000
+            else:
+                safe_limit = context_window - 10_000
+
+            if final_token_count > safe_limit:
+                logger.warning(f"⚠️ Token count {final_token_count} exceeds safe limit {safe_limit}, compressing messages...")
+                prepared_messages = self.context_manager.compress_messages(
+                    prepared_messages,
+                    llm_model,
+                    max_tokens=safe_limit
+                )
+                compressed_token_count = token_counter(model=llm_model, messages=prepared_messages)
+                logger.info(f"✅ Compressed messages: {final_token_count} → {compressed_token_count} tokens")
+
+                if compressed_token_count > safe_limit:
+                    logger.error(f"❌ Still over limit after compression: {compressed_token_count} > {safe_limit}")
+                    prepared_messages = self.context_manager.compress_messages_by_omitting_messages(
+                        prepared_messages,
+                        llm_model,
+                        max_tokens=safe_limit - 10_000
+                    )
+        except Exception as e:
+            logger.error(f"Error in token checking/compression: {str(e)}")
 
         system_count = sum(1 for msg in prepared_messages if msg.get('role') == 'system')
         if system_count > 1:
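Taken together, the block added above does three things: it only applies Anthropic cache formatting when the prompt is under 80,000 tokens, it derives a model-dependent safe limit from the context window, and it falls back to compression (and, failing that, message omission) when the prompt still exceeds that limit. A minimal sketch of the safe-limit rule alone, mirroring the thresholds in the hunk (the helper name is illustrative, not part of the codebase):

def safe_limit_for(context_window: int) -> int:
    # Same thresholds as the diff: fixed headroom for 200k-class models,
    # proportional headroom for smaller context windows.
    if context_window >= 200_000:
        return 168_000
    if context_window >= 100_000:
        return context_window - 20_000
    return context_window - 10_000

# e.g. a 200k model gets 168k, a 128k model 108k, a 32k model 22k
assert safe_limit_for(200_000) == 168_000
assert safe_limit_for(128_000) == 108_000
assert safe_limit_for(32_000) == 22_000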
@@ -508,7 +545,6 @@ When using the tools:
             prepared_messages = filtered_messages
             logger.info(f"🔧 Reduced to 1 system message")
 
-        # Debug: Log what we're sending
         logger.info(f"📤 Sending {len(prepared_messages)} messages to LLM")
         for i, msg in enumerate(prepared_messages):
             role = msg.get('role', 'unknown')
@@ -538,7 +574,7 @@ When using the tools:
         )
 
         llm_response = await make_llm_api_call(
-            prepared_messages, # Pass the potentially modified messages
+            prepared_messages,
            llm_model,
            temperature=llm_temperature,
            max_tokens=llm_max_tokens,
@@ -20,7 +20,7 @@ def get_resolved_model_id(model_name: str) -> str:
     return model_name
 
 
-def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 3584) -> Dict[str, Any]:
+def format_message_with_cache(message: Dict[str, Any], model_name: str, min_chars_for_cache: int = 10000) -> Dict[str, Any]:
     if not message or not isinstance(message, dict):
         logger.debug(f"Skipping cache format: message is not a dict")
         return message
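The default min_chars_for_cache rises from 3584 to 10000 characters, so only substantially large messages earn one of the scarce cache breakpoints. As a rough back-of-the-envelope estimate (assuming roughly four characters per token, an approximation rather than a constant from the code), that is on the order of 2,500 tokens per cached message:

MIN_CHARS = 10_000            # new default for min_chars_for_cache
APPROX_CHARS_PER_TOKEN = 4    # rough heuristic for English text
print(MIN_CHARS // APPROX_CHARS_PER_TOKEN)  # ~2500 tokens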
@@ -37,6 +37,7 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
         logger.debug(f"Content is already a list but no cache_control found")
         return message
 
+    # Increased min chars threshold to be more selective about what gets cached
     if len(str(content)) < min_chars_for_cache:
         logger.debug(f"Content too short for caching: {len(str(content))} < {min_chars_for_cache}")
         return message
@@ -68,11 +69,17 @@ def format_message_with_cache(message: Dict[str, Any], model_name: str, min_char
 
 
 def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
-                            max_messages_to_cache: int = 4) -> List[Dict[str, Any]]:
+                            max_messages_to_cache: int = 2) -> List[Dict[str, Any]]:
     if not messages:
         return messages
 
     resolved_model = get_resolved_model_id(model_name)
+    model_lower = resolved_model.lower()
+
+    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
+        logger.debug(f"Model {resolved_model} doesn't need cache_control blocks")
+        return messages
+
     logger.info(f"📊 apply_cache_to_messages called with {len(messages)} messages for model: {model_name} (resolved: {resolved_model})")
 
     formatted_messages = []
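The early return added above keeps cache_control blocks off requests to providers that do not use them; only Claude-family model identifiers pass the gate. A self-contained sketch of the same check (the helper name and model strings are illustrative):

ANTHROPIC_MARKERS = ('anthropic', 'claude', 'sonnet', 'haiku', 'opus')

def wants_cache_control(model_name: str) -> bool:
    # True only when the identifier looks like a Claude-family model.
    lower = model_name.lower()
    return any(marker in lower for marker in ANTHROPIC_MARKERS)

print(wants_cache_control("anthropic/claude-3-5-sonnet"))  # True
print(wants_cache_control("openai/gpt-4o"))                # False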
@@ -88,23 +95,29 @@ def apply_cache_to_messages(messages: List[Dict[str, Any]], model_name: str,
             formatted_messages.append(message)
             continue
 
-        if cache_count < max_messages_to_cache:
-            logger.debug(f"Processing message {i+1}/{len(messages)} for caching")
+        total_cached = already_cached_count + cache_count
+        if total_cached < max_messages_to_cache:
+            logger.debug(f"Processing message {i+1}/{len(messages)} for caching (total cached: {total_cached})")
             formatted_message = format_message_with_cache(message, resolved_model)
 
             if formatted_message != message:
                 cache_count += 1
-                logger.info(f"✅ Cache applied to message {i+1}")
+                logger.info(f"✅ Cache applied to message {i+1} (total cached: {already_cached_count + cache_count})")
 
             formatted_messages.append(formatted_message)
         else:
+            logger.debug(f"Skipping cache for message {i+1} - limit reached (total cached: {total_cached})")
             formatted_messages.append(message)
 
-    if cache_count > 0 or already_cached_count > 0:
-        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached for model {resolved_model}")
+    total_final = cache_count + already_cached_count
+    if total_final > 0:
+        logger.info(f"🎯 Caching status: {cache_count} newly cached, {already_cached_count} already cached, {total_final} total for model {resolved_model}")
     else:
         logger.debug(f"ℹ️ No messages needed caching for model {resolved_model}")
 
+    if total_final > max_messages_to_cache:
+        logger.warning(f"⚠️ Total cached messages ({total_final}) exceeds limit ({max_messages_to_cache})")
+
     return formatted_messages
 
 
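The accounting change is the substantive fix in this hunk: the old code compared only cache_count (messages newly cached in this pass) against the limit, so messages that already carried cache_control from an earlier pass did not count toward it, and repeated passes could push the total past the provider's block limit. Counting already_cached_count + cache_count closes that gap. A toy before/after comparison of the limit check (values are illustrative):

already_cached_count, cache_count, max_messages_to_cache = 2, 1, 2

old_check = cache_count < max_messages_to_cache                           # True  -> would cache yet another message
new_check = (already_cached_count + cache_count) < max_messages_to_cache  # False -> stops at the limit
print(old_check, new_check)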
@@ -135,3 +148,46 @@ def needs_cache_probe(model_name: str) -> bool:
 
     return any(provider in model_lower for provider in streaming_cache_issues)
 
+
+def validate_cache_blocks(messages: List[Dict[str, Any]], model_name: str, max_blocks: int = 4) -> List[Dict[str, Any]]:
+    resolved_model = get_resolved_model_id(model_name)
+    model_lower = resolved_model.lower()
+
+    if not any(provider in model_lower for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus']):
+        return messages
+
+    cache_block_count = 0
+    for msg in messages:
+        content = msg.get('content')
+        if isinstance(content, list) and content:
+            if isinstance(content[0], dict) and 'cache_control' in content[0]:
+                cache_block_count += 1
+
+    if cache_block_count <= max_blocks:
+        logger.debug(f"✅ Cache validation passed: {cache_block_count}/{max_blocks} blocks")
+        return messages
+
+    logger.warning(f"⚠️ Cache validation failed: {cache_block_count}/{max_blocks} blocks. Removing excess cache blocks.")
+
+    fixed_messages = []
+    blocks_seen = 0
+
+    for msg in messages:
+        content = msg.get('content')
+        if isinstance(content, list) and content:
+            if isinstance(content[0], dict) and 'cache_control' in content[0]:
+                blocks_seen += 1
+                if blocks_seen > max_blocks:
+                    logger.info(f"🔧 Removing cache_control from message {blocks_seen} (role: {msg.get('role')})")
+                    new_content = [{k: v for k, v in content[0].items() if k != 'cache_control'}]
+                    fixed_messages.append({**msg, 'content': new_content})
+                else:
+                    fixed_messages.append(msg)
+            else:
+                fixed_messages.append(msg)
+        else:
+            fixed_messages.append(msg)
+
+    logger.info(f"✅ Fixed cache blocks: {max_blocks}/{cache_block_count} blocks retained")
+    return fixed_messages
+
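A hedged usage sketch of the new validate_cache_blocks helper; the import path comes from the import change at the top of this commit, and the message shapes and model string are illustrative. Anthropic's prompt caching accepts at most four cache breakpoints per request, which is why cache_control entries beyond max_blocks are stripped before the call rather than sent:

from core.utils.llm_cache_utils import validate_cache_blocks

msgs = [
    {"role": "user",
     "content": [{"type": "text", "text": f"message {i}",
                  "cache_control": {"type": "ephemeral"}}]}
    for i in range(6)  # six cached blocks: two too many
]

fixed = validate_cache_blocks(msgs, "anthropic/claude-3-5-sonnet", max_blocks=4)
# The first four messages keep cache_control; the last two have it removed.
print(sum('cache_control' in m['content'][0] for m in fixed))  # 4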
@@ -1 +0,0 @@