Merge branch 'main' of https://github.com/kortix-ai/suna into fix-ui-bugs

Saumya 2025-10-02 21:51:07 +05:30
commit c00bc82b2b
4 changed files with 164 additions and 72 deletions

View File

@@ -88,9 +88,15 @@ class ContextManager:
else:
return msg_content
def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the tool result messages except the most recent one.
Compression is deterministic (simple truncation), ensuring consistent results across requests.
This allows prompt caching (applied later) to produce cache hits on identical compressed content.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
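The determinism guarantee in the docstring above is what makes later cache hits possible. safe_truncate itself is not shown in this diff, so the helper below is a hypothetical stand-in illustrating truncation that depends only on its inputs:

# Hypothetical stand-in for ContextManager.safe_truncate (its body is not in this diff):
# plain slicing plus a fixed marker, so identical inputs always produce identical outputs.
def truncate_deterministically(content: str, max_chars: int) -> str:
    if not isinstance(content, str) or len(content) <= max_chars:
        return content
    return content[:max_chars] + "\n... (truncated)"

# Because the result depends only on the input, re-compressing an unchanged history
# yields byte-identical messages, and a prompt cache applied afterwards can hit on them.
assert truncate_deterministically("x" * 5000, 1000) == truncate_deterministically("x" * 5000, 1000)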
@@ -112,9 +118,15 @@ class ContextManager:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the user messages except the most recent one.
Compression is deterministic (simple truncation), ensuring consistent results across requests.
This allows prompt caching (applied later) to produce cache hits on identical compressed content.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
@@ -136,9 +148,15 @@ class ContextManager:
msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
return messages
def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one."""
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000, uncompressed_total_token_count: Optional[int] = None) -> List[Dict[str, Any]]:
"""Compress the assistant messages except the most recent one.
Compression is deterministic (simple truncation), ensuring consistent results across requests.
This allows prompt caching (applied later) to produce cache hits on identical compressed content.
"""
if uncompressed_total_token_count is None:
uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
max_tokens_value = max_tokens or (100 * 1000)
if uncompressed_total_token_count > max_tokens_value:
@@ -188,15 +206,10 @@ class ContextManager:
result.append(msg)
return result
def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
"""Compress the messages.
def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5, actual_total_tokens: Optional[int] = None, system_prompt: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""Compress the messages WITHOUT applying caching during iterations.
Args:
messages: List of messages to compress
llm_model: Model name for token counting
max_tokens: Maximum allowed tokens
token_threshold: Token threshold for individual message compression (must be a power of 2)
max_iterations: Maximum number of compression iterations
Caching should be applied ONCE at the end by the caller, not during compression.
"""
# Get model-specific token limits from constants
context_window = model_manager.get_context_window(llm_model)
@@ -218,24 +231,46 @@ class ContextManager:
result = messages
result = self.remove_meta_messages(result)
uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
# Calculate initial token count - just conversation + system prompt, NO caching overhead
print(f"actual_total_tokens: {actual_total_tokens}")
if actual_total_tokens is not None:
uncompressed_total_token_count = actual_total_tokens
else:
print("no actual_total_tokens")
# Count conversation + system prompt WITHOUT caching
if system_prompt:
uncompressed_total_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
logger.info(f"Initial token count (no caching): {uncompressed_total_token_count}")
result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
# Apply compression
result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold, uncompressed_total_token_count)
compressed_token_count = token_counter(model=llm_model, messages=result)
logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_token_count} tokens")
# Recalculate WITHOUT caching overhead
if system_prompt:
compressed_total = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
compressed_total = token_counter(model=llm_model, messages=result)
logger.info(f"Context compression: {uncompressed_total_token_count} -> {compressed_total} token")
# Recurse if still too large
if max_iterations <= 0:
logger.warning(f"compress_messages: Max iterations reached, omitting messages")
result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
logger.warning(f"Max iterations reached, omitting messages")
result = self.compress_messages_by_omitting_messages(result, llm_model, max_tokens, system_prompt=system_prompt)
return result
if compressed_token_count > max_tokens:
logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)
if compressed_total > max_tokens:
logger.warning(f"Further compression needed: {compressed_total} > {max_tokens}")
# Recursive call - still NO caching
result = self.compress_messages(
result, llm_model, max_tokens,
token_threshold // 2, max_iterations - 1,
compressed_total, system_prompt,
)
return self.middle_out_messages(result)
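Taken together with the docstring ("caching should be applied ONCE at the end by the caller"), the intended call pattern looks roughly like the sketch below; apply_caching_once is a hypothetical placeholder for the repository's caching entry point (apply_anthropic_caching_strategy), whose exact signature is not shown here:

# Assumed caller flow, not the literal ThreadManager code:
context_manager = ContextManager()

# 1. Compress deterministically. No cache_control blocks exist yet, so every
#    token count taken inside compress_messages reflects only real content.
compressed = context_manager.compress_messages(
    messages, llm_model, max_tokens=llm_max_tokens,
    actual_total_tokens=None,     # let the method count conversation + system prompt
    system_prompt=system_prompt,  # counted for sizing, but not returned in the list
)

# 2. Apply prompt caching exactly once, on the final compressed messages.
prepared_messages = apply_caching_once(system_prompt, compressed)  # hypothetical wrapper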
@@ -245,7 +280,8 @@ class ContextManager:
llm_model: str,
max_tokens: Optional[int] = 41000,
removal_batch_size: int = 10,
min_messages_to_keep: int = 10
min_messages_to_keep: int = 10,
system_prompt: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""Compress the messages by omitting messages from the middle.
@@ -263,15 +299,19 @@ class ContextManager:
result = self.remove_meta_messages(result)
# Early exit if no compression needed
initial_token_count = token_counter(model=llm_model, messages=result)
if system_prompt:
initial_token_count = token_counter(model=llm_model, messages=[system_prompt] + result)
else:
initial_token_count = token_counter(model=llm_model, messages=result)
max_allowed_tokens = max_tokens or (100 * 1000)
if initial_token_count <= max_allowed_tokens:
return result
# Separate system message (assumed to be first) from conversation messages
system_message = messages[0] if messages and isinstance(messages[0], dict) and messages[0].get('role') == 'system' else None
conversation_messages = result[1:] if system_message else result
system_message = system_prompt
conversation_messages = result
safety_limit = 500
current_token_count = initial_token_count
@@ -302,9 +342,14 @@ class ContextManager:
messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
current_token_count = token_counter(model=llm_model, messages=messages_to_count)
# Prepare final result
final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
final_token_count = token_counter(model=llm_model, messages=final_messages)
# Prepare final result - return only conversation messages (matches compress_messages pattern)
final_messages = conversation_messages
# Log with system prompt included for accurate token reporting
if system_message:
final_token_count = token_counter(model=llm_model, messages=[system_message] + final_messages)
else:
final_token_count = token_counter(model=llm_model, messages=final_messages)
logger.info(f"Context compression (omit): {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")

View File

@@ -232,6 +232,12 @@ def apply_anthropic_caching_strategy(
This prevents cache invalidation while optimizing for context window utilization
and cost efficiency across different conversation patterns.
"""
# DEBUG: Count message roles to verify tool results are included
message_roles = [msg.get('role', 'unknown') for msg in conversation_messages]
role_counts = {}
for role in message_roles:
role_counts[role] = role_counts.get(role, 0) + 1
logger.debug(f"🔍 CACHING INPUT: {len(conversation_messages)} messages - Roles: {role_counts}")
if not conversation_messages:
conversation_messages = []
@@ -256,10 +262,15 @@ def apply_anthropic_caching_strategy(
# Calculate mathematically optimized cache threshold
if cache_threshold_tokens is None:
# Include system prompt tokens in calculation for accurate density (like compression does)
# Use token_counter on combined messages to match compression's calculation method
from litellm import token_counter
total_tokens = token_counter(model=model_name, messages=[working_system_prompt] + conversation_messages) if conversation_messages else 0
cache_threshold_tokens = calculate_optimal_cache_threshold(
context_window_tokens,
len(conversation_messages),
get_messages_token_count(conversation_messages, model_name) if conversation_messages else 0
total_tokens # Now includes system prompt for accurate density calculation
)
logger.info(f"📊 Applying single cache breakpoint strategy for {len(conversation_messages)} messages")
@@ -307,6 +318,7 @@ def apply_anthropic_caching_strategy(
max_cacheable_tokens = int(context_window_tokens * 0.8)
if total_conversation_tokens <= max_cacheable_tokens:
logger.debug(f"Conversation fits within cache limits - use chunked approach")
# Conversation fits within cache limits - use chunked approach
chunks_created = create_conversation_chunks(
conversation_messages,
@@ -350,6 +362,7 @@ def create_conversation_chunks(
Final messages are NEVER cached to prevent cache invalidation.
Returns number of cache blocks created.
"""
logger.debug(f"Creating conversation chunks - chunk threshold: {chunk_threshold_tokens}, max blocks: {max_blocks}")
if not messages or max_blocks <= 0:
return 0
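Given the stated contract (chunk earlier messages by token count, never cache the final messages, cap the number of blocks), a chunking pass can be sketched as below; this illustrates the idea rather than the repository's implementation, and plan_cache_breakpoints is a hypothetical name:

from litellm import token_counter

def plan_cache_breakpoints(messages, chunk_threshold_tokens, max_blocks, model_name):
    """Return indices after which a cache breakpoint could be placed."""
    breakpoints, running = [], 0
    for i, msg in enumerate(messages[:-1]):  # the final message is never cached
        running += token_counter(model=model_name, messages=[msg])
        if running >= chunk_threshold_tokens and len(breakpoints) < max_blocks:
            breakpoints.append(i)
            running = 0
    return breakpoints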

View File

@@ -17,6 +17,7 @@ from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from core.services.langfuse import langfuse
from datetime import datetime, timezone
from core.billing.billing_integration import billing_integration
from litellm.utils import token_counter
ToolChoice = Literal["auto", "required", "none"]
@@ -305,8 +306,11 @@ class ThreadManager:
if ENABLE_CONTEXT_MANAGER:
logger.debug(f"Context manager enabled, compressing {len(messages)} messages")
context_manager = ContextManager()
compressed_messages = context_manager.compress_messages(
messages, llm_model, max_tokens=llm_max_tokens
messages, llm_model, max_tokens=llm_max_tokens,
actual_total_tokens=None, # Will be calculated inside
system_prompt=system_prompt # KEY FIX: No caching during compression
)
logger.debug(f"Context compression completed: {len(messages)} -> {len(compressed_messages)} messages")
messages = compressed_messages
@@ -340,6 +344,10 @@ class ThreadManager:
except Exception as e:
logger.warning(f"Failed to update Langfuse generation: {e}")
# Log final prepared messages token count
final_prepared_tokens = token_counter(model=llm_model, messages=prepared_messages)
logger.info(f"📤 Final prepared messages being sent to LLM: {final_prepared_tokens} tokens")
# Make LLM call
try:
llm_response = await make_llm_api_call(

View File

@@ -53,26 +53,52 @@ export function renderAttachments(attachments: string[], fileViewerHandler?: (fi
// Render Markdown content while preserving XML tags that should be displayed as tool calls
function preprocessTextOnlyTools(content: string): string {
console.log('🔍 preprocessTextOnlyTools called with:', typeof content, content);
if (!content || typeof content !== 'string') {
console.warn('❌ preprocessTextOnlyTools: Invalid content type:', typeof content, content);
return content || '';
}
// Handle new function calls format for text-only tools - extract text parameter content
// Complete XML format
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)<\/parameter>[\s\S]*?<\/invoke>\s*<\/function_calls>/gi, '$1');
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)<\/parameter>[\s\S]*?<\/invoke>\s*<\/function_calls>/gi, '$1');
// For ask/complete tools, we need to preserve them if they have attachments
// Only strip them if they don't have attachments parameter
// Handle new function calls format - only strip if no attachments
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)<\/parameter>\s*<\/invoke>\s*<\/function_calls>/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="present_presentation">[\s\S]*?<parameter name="text">([\s\S]*?)<\/parameter>[\s\S]*?<\/invoke>\s*<\/function_calls>/gi, '$1');
// Handle streaming/partial XML for message tools - extract text parameter content even if incomplete
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
// Handle streaming/partial XML for message tools - only strip if no attachments visible yet
content = content.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)$/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="ask">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)$/gi, (match) => {
if (match.includes('<parameter name="attachments"')) return match;
return match.replace(/<function_calls>\s*<invoke name="complete">\s*<parameter name="text">([\s\S]*?)$/gi, '$1');
});
content = content.replace(/<function_calls>\s*<invoke name="present_presentation">[\s\S]*?<parameter name="text">([\s\S]*?)$/gi, '$1');
// Also handle old format for backward compatibility
content = content.replace(/<ask[^>]*>([\s\S]*?)<\/ask>/gi, '$1');
content = content.replace(/<complete[^>]*>([\s\S]*?)<\/complete>/gi, '$1');
// Also handle old format - only strip if no attachments attribute
content = content.replace(/<ask[^>]*>([\s\S]*?)<\/ask>/gi, (match) => {
if (match.match(/<ask[^>]*attachments=/i)) return match;
return match.replace(/<ask[^>]*>([\s\S]*?)<\/ask>/gi, '$1');
});
content = content.replace(/<complete[^>]*>([\s\S]*?)<\/complete>/gi, (match) => {
if (match.match(/<complete[^>]*attachments=/i)) return match;
return match.replace(/<complete[^>]*>([\s\S]*?)<\/complete>/gi, '$1');
});
content = content.replace(/<present_presentation[^>]*>([\s\S]*?)<\/present_presentation>/gi, '$1');
return content;
}
@@ -88,7 +114,7 @@ export function renderMarkdownContent(
) {
// Preprocess content to convert text-only tools to natural text
content = preprocessTextOnlyTools(content);
// If in debug mode, just display raw content in a pre tag
if (debugMode) {
return (
@@ -139,7 +165,7 @@ export function renderMarkdownContent(
{renderAttachments(attachmentArray, fileViewerHandler, sandboxId, project)}
</div>
);
// Also render standalone attachments outside the message
const standaloneAttachments = renderStandaloneAttachments(attachmentArray, fileViewerHandler, sandboxId, project);
if (standaloneAttachments) {
@@ -165,7 +191,7 @@ export function renderMarkdownContent(
{renderAttachments(attachmentArray, fileViewerHandler, sandboxId, project)}
</div>
);
// Also render standalone attachments outside the message
const standaloneAttachments = renderStandaloneAttachments(attachmentArray, fileViewerHandler, sandboxId, project);
if (standaloneAttachments) {
@@ -268,7 +294,7 @@ export function renderMarkdownContent(
{renderAttachments(attachments, fileViewerHandler, sandboxId, project)}
</div>
);
// Also render standalone attachments outside the message
const standaloneAttachments = renderStandaloneAttachments(attachments, fileViewerHandler, sandboxId, project);
if (standaloneAttachments) {
@@ -296,7 +322,7 @@ export function renderMarkdownContent(
{renderAttachments(attachments, fileViewerHandler, sandboxId, project)}
</div>
);
// Also render standalone attachments outside the message
const standaloneAttachments = renderStandaloneAttachments(attachments, fileViewerHandler, sandboxId, project);
if (standaloneAttachments) {
@@ -671,13 +697,13 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
// Use merged groups instead of original grouped messages
const finalGroupedMessages = mergedGroups;
// Helper function to add streaming content to groups
const appendStreamingContent = (content: string, isPlayback: boolean = false) => {
const messageId = isPlayback ? 'playbackStreamingText' : 'streamingTextContent';
const metadata = isPlayback ? 'playbackStreamingText' : 'streamingTextContent';
const keySuffix = isPlayback ? 'playback-streaming' : 'streaming';
const lastGroup = finalGroupedMessages.at(-1);
if (!lastGroup || lastGroup.type === 'user') {
// Create new assistant group for streaming content
@@ -770,7 +796,7 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
<div key={group.key} className="space-y-3">
{/* All file attachments rendered outside message bubble */}
{renderStandaloneAttachments(attachments as string[], handleOpenFileViewer, sandboxId, project, true)}
<div className="flex justify-end">
<div className="flex max-w-[85%] rounded-3xl rounded-br-lg bg-card border px-4 py-3 break-words overflow-hidden">
<div className="space-y-3 min-w-0 flex-1">
@@ -789,7 +815,7 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
// Get agent_id from the first assistant message in this group
const firstAssistantMsg = group.messages.find(m => m.type === 'assistant');
const groupAgentId = firstAssistantMsg?.agent_id;
return (
<div key={group.key} ref={groupIndex === groupedMessages.length - 1 ? latestMessageRef : null}>
<div className="flex flex-col gap-2">
@@ -903,7 +929,7 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
// Preprocess content first to remove text-only tool tags
const textToRender = preprocessTextOnlyTools(streamingTextContent || '');
let detectedTag: string | null = null;
let tagStartIndex = -1;
if (textToRender) {
@@ -927,11 +953,11 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
}
const textBeforeTag = detectedTag ? textToRender.substring(0, tagStartIndex) : textToRender;
const showCursor =
(streamHookStatus ===
'streaming' ||
streamHookStatus ===
'connecting') &&
!detectedTag;
(streamHookStatus ===
'streaming' ||
streamHookStatus ===
'connecting') &&
!detectedTag;
// Show minimal processing indicator when agent is active but no streaming text after preprocessing
if (!textToRender && (streamHookStatus === 'streaming' || streamHookStatus === 'connecting')) {
@@ -946,8 +972,8 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
return (
<>
<StreamingText
content={textBeforeTag}
<StreamingText
content={textBeforeTag}
className="text-sm prose prose-sm dark:prose-invert chat-markdown max-w-none [&>:first-child]:mt-0 prose-headings:mt-3 break-words overflow-wrap-anywhere"
/>
@@ -972,7 +998,7 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
{(() => {
// Preprocess content first to remove text-only tool tags
const textToRender = preprocessTextOnlyTools(streamingText || '');
let detectedTag: string | null = null;
let tagStartIndex = -1;
if (textToRender) {
@@ -1009,8 +1035,8 @@ export const ThreadContent: React.FC<ThreadContentProps> = ({
</pre>
) : (
<>
<StreamingText
content={textBeforeTag}
<StreamingText
content={textBeforeTag}
className="text-sm prose prose-sm dark:prose-invert chat-markdown max-w-none [&>:first-child]:mt-0 prose-headings:mt-3 break-words overflow-wrap-anywhere"
/>