Merge branch 'main' of https://github.com/kortix-ai/suna into fix-ui-bugs

Saumya 2025-10-02 21:51:07 +05:30
commit c00bc82b2b
4 changed files with 164 additions and 72 deletions

View File

@ -88,9 +88,15 @@ class ContextManager:
else:
return msg_content
def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one."""
def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one.
Compression is deterministic (simple truncation), ensuring consistent results across requests.
This allows prompt caching (applied later) to produce cache hits on identical compressed content.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
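As the new docstring notes, the compression itself is just deterministic truncation of older tool results. A minimal sketch of the idea, assuming tool results carry role "tool" and that truncation is a plain string slice (both simplifications of the real helpers):

# Sketch only (not part of the diff): deterministic truncation of older tool
# results, so identical requests always compress to identical content and a
# later prompt-caching pass can hit on the same prefix.
from typing import Any, Dict, List
from litellm import token_counter

def truncate_old_tool_results(messages: List[Dict[str, Any]], llm_model: str,
                              token_threshold: int = 1000) -> List[Dict[str, Any]]:
    last_tool_idx = max((i for i, m in enumerate(messages) if m.get("role") == "tool"), default=-1)
    for i, msg in enumerate(messages):
        if msg.get("role") != "tool" or i == last_tool_idx:
            continue  # never touch the most recent tool result
        if token_counter(model=llm_model, messages=[msg]) > token_threshold:
            text = str(msg["content"])
            # Plain slicing: no randomness, no model calls, so output is stable across requests.
            msg["content"] = text[: token_threshold * 3] + "... (truncated)"
    return messages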
@ -112,9 +118,15 @@ class ContextManager:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one."""
def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one.
Compression is deterministic (simple truncation), ensuring consistent results across requests.
This allows prompt caching (applied later) to produce cache hits on identical compressed content.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
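The extra uncompressed_total_token_count parameter on all three compress_* methods lets a caller count tokens once and reuse the figure across passes instead of re-running token_counter each time. Roughly, with variable names assumed from context:

# Sketch: count once, pass the same total to every compression pass.
cm = ContextManager()
total = token_counter(model=llm_model, messages=[system_prompt] + messages)
messages = cm.compress_tool_result_messages(messages, llm_model, max_tokens, token_threshold, total)
messages = cm.compress_user_messages(messages, llm_model, max_tokens, token_threshold, total)
messages = cm.compress_assistant_messages(messages, llm_model, max_tokens, token_threshold, total)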
@ -136,9 +148,15 @@ class ContextManager:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one."""
def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one.
Compression is deterministic (simple truncation), ensuring consistent results across requests.
This allows prompt caching (applied later) to produce cache hits on identical compressed content.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
@ -188,15 +206,10 @@ class ContextManager:
result.append(msg)
return result
def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
"""Compress the messages.
def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5, actual_total_tokens: Optional[int] = None, system_prompt: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""Compress the messages WITHOUT applying caching during iterations.
Args:
messages: List of messages to compress
llm_model: Model name for token counting
max_tokens: Maximum allowed tokens
token_threshold: Token threshold for individual message compression (must be a power of 2)
max_iterations: Maximum number of compression iterations
Caching should be applied ONCE at the end by the caller, not during compression.
"""
# Get model-specific token limits from constants
context_window = model_manager.get_context_window(llm_model)
@ -218,24 +231,46 @@ class ContextManager:
result = messages
result = self.remove_meta_messages(result)
# Calculate initial token count - just conversation + system prompt, NO caching overhead
print(f"actual_total_tokens: {actual_total_tokens}")
if actual_total_tokens is not None:
uncompressed_total_token_count = actual_total_tokens
else:
print("no actual_total_tokens")
# Count conversation + system prompt WITHOUT caching
if system_prompt:
uncompressed_total_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
logger.info(f"Initial token count (no caching): {uncompressed_total_token_count}")
result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
# Apply compression
result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
compressed_token_count = token_counter(model=llm_model, messages=result)
# Recalculate WITHOUT caching overhead
if system_prompt:
compressed_total = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
compressed_total = token_counter(model=llm_model, messages=result)
logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_token_count} tokens")
logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_total} token")
# Recurse if still too large
if max_iterations <= 0:
logger.warning(f"compress_messages: Max iterations reached, omitting messages")
result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
logger.warning(f"Max iterations reached, omitting messages")
result = self.compress_messages_by_omitting_messages(result, llm_model, max_tokens, system_prompt=system_prompt)
return result
if compressed_token_count > max_tokens:
logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)
if compressed_total > max_tokens:
logger.warning(f"Further compression needed: {compressed_total} > {max_tokens}")
# Recursive call - still NO caching
result = self.compress_messages(
result, llm_model, max_tokens,
token_threshold // 2, max_iterations - 1,
compressed_total, system_prompt,
)
return self.middle_out_messages(result)
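Stripped of logging, the control flow of the reworked compress_messages amounts to: compress, re-count without caching overhead, recurse with a halved per-message threshold while the budget is still exceeded, and fall back to omitting whole messages when the iteration budget runs out. A condensed sketch, with limits simplified and method bodies elided:

# Condensed sketch of the recursion; the real method also logs and reads
# model-specific limits from model_manager.
def compress_sketch(self, messages, llm_model, max_tokens, token_threshold,
                    max_iterations, actual_total_tokens=None, system_prompt=None):
    limit = max_tokens or (100 * 1000)
    result = self.remove_meta_messages(messages)
    with_system = ([system_prompt] + result) if system_prompt else result
    total = actual_total_tokens or token_counter(model=llm_model, messages=with_system)
    result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold, total)
    result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold, total)
    result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold, total)
    compressed_total = token_counter(
        model=llm_model, messages=([system_prompt] + result) if system_prompt else result)
    if max_iterations <= 0:
        # Out of iterations: drop whole messages instead of truncating further.
        return self.compress_messages_by_omitting_messages(result, llm_model, max_tokens,
                                                           system_prompt=system_prompt)
    if compressed_total > limit:
        result = compress_sketch(self, result, llm_model, max_tokens, token_threshold // 2,
                                 max_iterations - 1, compressed_total, system_prompt)
    return self.middle_out_messages(result)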
@ -245,7 +280,8 @@ class ContextManager:
llm_model: str,
max_tokens: Optional[int] = 41000,
removal_batch_size: int = 10,
min_messages_to_keep: int = 10
min_messages_to_keep: int = 10,
system_prompt: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""Compress the messages by omitting messages from the middle.
@ -263,15 +299,19 @@ class ContextManager:
result = self.remove_meta_messages(result)
# Early exit if no compression needed
if system_prompt:
initial_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
initial_token_count = token_counter(model=llm_model, messages=result)
max_allowed_tokens = max_tokens or (100 * 1000)
if initial_token_count <= max_allowed_tokens:
return result
# Separate system message (assumed to be first) from conversation messages
system_message = messages[0] if messages and isinstance(messages[0], dict) and messages[0].get('role') == 'system' else None
conversation_messages = result[1:] if system_message else result
system_message = system_prompt
conversation_messages = result
safety_limit = 500
current_token_count = initial_token_count
@ -302,8 +342,13 @@ class ContextManager:
messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
current_token_count = token_counter(model=llm_model, messages=messages_to_count)
# Prepare final result
final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
# Prepare final result - return only conversation messages (matches compress_messages pattern)
final_messages = conversation_messages
# Log with system prompt included for accurate token reporting
if system_message:
final_token_count = token_counter(model=llm_model, messages=[system_message] + final_messages)
else:
final_token_count = token_counter(model=llm_model, messages=final_messages)
logger.info(f"Context compression (omit): {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")

View File

@ -232,6 +232,12 @@ def apply_anthropic_caching_strategy(
This prevents cache invalidation while optimizing for context window utilization
and cost efficiency across different conversation patterns.
"""
# DEBUG: Count message roles to verify tool results are included
message_roles = [msg.get('role', 'unknown') for msg in conversation_messages]
role_counts = {}
for role in message_roles:
role_counts[role] = role_counts.get(role, 0) + 1
logger.debug(f"🔍 CACHING INPUT: {len(conversation_messages)} messages - Roles: {role_counts}")
if not conversation_messages:
conversation_messages = []
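The manual role tally above could equivalently be written with collections.Counter; a one-liner alternative for reference:

from collections import Counter
role_counts = Counter(msg.get('role', 'unknown') for msg in conversation_messages)
logger.debug(f"🔍 CACHING INPUT: {len(conversation_messages)} messages - Roles: {dict(role_counts)}")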
@ -256,10 +262,15 @@ def apply_anthropic_caching_strategy(
# Calculate mathematically optimized cache threshold
if cache_threshold_tokens is None:
# Include system prompt tokens in calculation for accurate density (like compression does)
# Use token_counter on combined messages to match compression's calculation method
from litellm import token_counter
total_tokens = token_counter(model=model_name, messages=[working_system_prompt] + conversation_messages) if conversation_messages else 0
cache_threshold_tokens = calculate_optimal_cache_threshold(
context_window_tokens,
len(conversation_messages),
get_messages_token_count(conversation_messages, model_name) if conversation_messages else 0
total_tokens # Now includes system prompt for accurate density calculation
)
logger.info(f"📊 Applying single cache breakpoint strategy for {len(conversation_messages)} messages")
@ -307,6 +318,7 @@ def apply_anthropic_caching_strategy(
max_cacheable_tokens = int(context_window_tokens * 0.8)
if total_conversation_tokens <= max_cacheable_tokens:
logger.debug(f"Conversation fits within cache limits - use chunked approach")
# Conversation fits within cache limits - use chunked approach
chunks_created = create_conversation_chunks(
conversation_messages,
@ -350,6 +362,7 @@ def create_conversation_chunks(
Final messages are NEVER cached to prevent cache invalidation.
Returns number of cache blocks created.
"""
logger.debug(f"Creating conversation chunks - chunk threshold: {chunk_threshold_tokens}, max blocks: {max_blocks}")
if not messages or max_blocks <= 0:
return 0
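A sketch of the chunked approach: walk the cacheable prefix of the conversation accumulating tokens, and mark a cache breakpoint whenever a chunk reaches the threshold, leaving the most recent messages unmarked so the live tail never invalidates a cached prefix. Anthropic's cache_control block format is real; the placement and tail-skipping details below are assumptions:

# Sketch: place up to max_blocks cache breakpoints on chunk boundaries,
# never on the trailing messages (tail_to_skip is an assumed knob).
def mark_cache_breakpoints(messages, llm_model, chunk_threshold_tokens, max_blocks, tail_to_skip=2):
    blocks, running = 0, 0
    cacheable = messages[:-tail_to_skip] if len(messages) > tail_to_skip else []
    for msg in cacheable:
        running += token_counter(model=llm_model, messages=[msg])
        if running >= chunk_threshold_tokens and blocks < max_blocks:
            content = msg["content"]
            if isinstance(content, str):
                msg["content"] = [{"type": "text", "text": content,
                                   "cache_control": {"type": "ephemeral"}}]
            elif isinstance(content, list) and content:
                content[-1]["cache_control"] = {"type": "ephemeral"}
            blocks += 1
            running = 0
    return blocks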

View File

@ -17,6 +17,7 @@ from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from core.services.langfuse import langfuse
from datetime import datetime, timezone
from core.billing.billing_integration import billing_integration
from litellm.utils import token_counter
ToolChoice = Literal["auto", "required", "none"]
@ -305,8 +306,11 @@ class ThreadManager:
if ENABLE_CONTEXT_MANAGER:
logger.debug(f"Context manager enabled, compressing {len(messages)} messages")
context_manager = ContextManager()
compressed_messages = context_manager.compress_messages(
messages, llm_model, max_tokens=llm_max_tokens
messages, llm_model, max_tokens=llm_max_tokens,
actual_total_tokens=None, # Will be calculated inside
system_prompt=system_prompt # KEY FIX: No caching during compression
)
logger.debug(f"Context compression completed: {len(messages)} -> {len(compressed_messages)} messages")
messages = compressed_messages
@ -340,6 +344,10 @@ class ThreadManager:
except Exception as e:
logger.warning(f"Failed to update Langfuse generation: {e}")
# Log final prepared messages token count
final_prepared_tokens = token_counter(model=llm_model, messages=prepared_messages)
logger.info(f"📤 Final prepared messages being sent to LLM: {final_prepared_tokens} tokens")
# Make LLM call
try:
llm_response = await make_llm_api_call(
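Taken together with the ContextManager changes, the caller-side order this commit aims for is: compress without caching, apply the caching strategy exactly once, log the final payload size, then call the LLM. A condensed sketch, with helper signatures simplified:

# Condensed sketch of the caller-side order; helper signatures are simplified.
async def prepare_and_call(context_manager, system_prompt, messages, llm_model, llm_max_tokens):
    compressed = context_manager.compress_messages(
        messages, llm_model, max_tokens=llm_max_tokens,
        actual_total_tokens=None, system_prompt=system_prompt)      # 1. compress, no caching inside
    prepared_messages = apply_caching_once(system_prompt, compressed, llm_model)  # 2. cache breakpoints, once
    final_prepared_tokens = token_counter(model=llm_model, messages=prepared_messages)
    logger.info(f"📤 Final prepared messages being sent to LLM: {final_prepared_tokens} tokens")  # 3. log payload
    return await make_llm_api_call(prepared_messages, llm_model)    # 4. LLM call (arguments simplified)

Here apply_caching_once stands in for apply_anthropic_caching_strategy, whose full signature is not shown in the diff.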

View File

@ -59,20 +59,46 @@ function preprocessTextOnlyTools(content: string): string {
return content || '';
}
// Handle new function calls format for text-only tools - extract text parameter content
// Complete XML format
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)<\/parameter>[\s\S]*?<\/invoke>\s*<\/function_calls>/gi, '$1');
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)<\/parameter>[\s\S]*?<\/invoke>\s*<\/function_calls>/gi, '$1');
// For ask/complete tools, we need to preserve them if they have attachments
// Only strip them if they don't have attachments parameter
// Handle new function calls format - only strip if no attachments
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="present_presentation">[\s\S]*?<parameter name="text">([\s\S]*?)<\/parameter>[\s\S]*?<\/invoke>\s*<\/function_calls>/gi, '$1');
// Handle streaming/partial XML for message tools - extract text parameter content even if incomplete
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
// Handle streaming/partial XML for message tools - only strip if no attachments visible yet
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)$/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)$/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="present_presentation">[\s\S]*?<parameter name="text">([\s\S]*?)$/gi, '$1');
// Also handle old format for backward compatibility
content = content.replace(/<ask[^>]*>([\s\S]*?)<\/ask>/gi, '$1');
content = content.replace(/<complete[^>]*>([\s\S]*?)<\/complete>/gi, '$1');
// Also handle old format - only strip if no attachments attribute
content = content.replace(/<ask[^>]*>([\s\S]*?)<\/ask>/gi, (match) => {
if (match.match(/<ask[^>]*attachments=/i)) return match;
return match.replace(/<ask[^>]*>([\s\S]*?)<\/ask>/gi, '$1');
});
content = content.replace(/<complete[^>]*>([\s\S]*?)<\/complete>/gi, (match) => {
if (match.match(/<complete[^>]*attachments=/i)) return match;
return match.replace(/<complete[^>]*>([\s\S]*?)<\/complete>/gi, '$1');
});
content = content.replace(/<present_presentation[^>]*>([\s\S]*?)<\/present_presentation>/gi, '$1');
return content;
}
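For comparison, the same preserve-if-attachments branching expressed as a Python re.sub with a callback; the shipped code is the TypeScript above, and this analogue only illustrates the logic for the complete-XML ask case:

import re

# Python analogue of the TypeScript above: unwrap the <invoke name="ask"> text
# parameter only when the matched call carries no attachments parameter.
ASK_CALL = re.compile(
    r'<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)'
    r'</parameter>\s*</invoke>\s*</function_calls>',
    re.IGNORECASE,
)

def strip_ask_calls(content: str) -> str:
    def unwrap(match: re.Match) -> str:
        if '<parameter name="attachments"' in match.group(0):
            return match.group(0)   # attachments present: keep the whole call so they still render
        return match.group(1)       # otherwise surface just the text content
    return ASK_CALL.sub(unwrap, content)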