Merge branch 'PRODUCTION' of https://github.com/kortix-ai/suna into toasts-fix

This commit is contained in:
Soumyadas15 2025-06-06 16:09:32 +05:30
commit ed66a0c0f8
5 changed files with 134 additions and 161 deletions

2
.gitignore vendored
View File

@@ -198,3 +198,5 @@ rabbitmq_data
.setup_progress
.setup_env.json
backend/.test_token_compression.py

View File

@@ -25,6 +25,7 @@ from utils.logger import logger
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from services.langfuse import langfuse
import datetime
from litellm import token_counter
# Type alias for tool choice
ToolChoice = Literal["auto", "required", "none"]
@@ -74,6 +75,134 @@ class ThreadManager:
except (json.JSONDecodeError, TypeError):
pass
return False
def _compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
"""Compress the message content."""
# print("max_length", max_length)
if isinstance(msg_content, str):
if len(msg_content) > max_length:
return msg_content[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
else:
return msg_content
elif isinstance(msg_content, dict):
if len(json.dumps(msg_content)) > max_length:
return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
else:
return msg_content
def _safe_truncate(self, msg_content: Union[str, dict], max_length: int = 200000) -> Union[str, dict]:
"""Truncate the message content safely."""
if isinstance(msg_content, str):
if len(msg_content) > max_length:
return msg_content[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
else:
return msg_content
elif isinstance(msg_content, dict):
if len(json.dumps(msg_content)) > max_length:
return json.dumps(msg_content)[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
else:
return msg_content
def _compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
    """Compress ToolResult messages, sparing the most recent one.

    Walks the conversation newest-to-oldest; every ToolResult message over
    token_threshold tokens (except the newest) is compressed with an
    expand-message hint keyed by its message_id. The newest oversized
    ToolResult is only safe-truncated with a generous budget, since the
    model still needs its detail. Messages are mutated in place.

    Args:
        messages: Conversation messages; mutated in place and returned.
        llm_model: Model name passed to token_counter.
        max_tokens: Total-token budget; defaults to 64k when None.
        token_threshold: Per-message token limit before compression.

    Returns:
        The same messages list, possibly with compressed content.
    """
    # Fix: resolve the budget once. The original guarded the comparison with
    # `max_tokens or 64000` but still computed `int(max_tokens * 2)` below,
    # which raised TypeError whenever max_tokens was None.
    effective_limit = max_tokens or (64 * 1000)
    uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
    if uncompressed_total_token_count > effective_limit:
        tool_result_count = 0  # How many ToolResult messages seen so far (newest first)
        for msg in reversed(messages):
            if self._is_tool_result_message(msg):
                tool_result_count += 1
                msg_token_count = token_counter(messages=[msg])
                if msg_token_count > token_threshold:
                    if tool_result_count > 1:  # Not the most recent ToolResult
                        message_id = msg.get('message_id')
                        if message_id:
                            msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
                        else:
                            logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                    else:
                        # Newest ToolResult: keep as much as possible.
                        msg["content"] = self._safe_truncate(msg["content"], int(effective_limit * 2))
    return messages
def _compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
    """Compress user messages, sparing the most recent one.

    Same strategy as _compress_tool_result_messages but filtered on
    role == 'user' and with a 100k default budget. Messages are mutated
    in place.

    Args:
        messages: Conversation messages; mutated in place and returned.
        llm_model: Model name passed to token_counter.
        max_tokens: Total-token budget; defaults to 100k when None.
        token_threshold: Per-message token limit before compression.

    Returns:
        The same messages list, possibly with compressed content.
    """
    # Fix: resolve the budget once so `int(max_tokens * 2)` can no longer
    # raise TypeError when max_tokens is None.
    effective_limit = max_tokens or (100 * 1000)
    uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
    if uncompressed_total_token_count > effective_limit:
        user_count = 0  # How many user messages seen so far (newest first)
        for msg in reversed(messages):
            if msg.get('role') == 'user':
                user_count += 1
                msg_token_count = token_counter(messages=[msg])
                if msg_token_count > token_threshold:
                    if user_count > 1:  # Not the most recent user message
                        message_id = msg.get('message_id')
                        if message_id:
                            msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
                        else:
                            logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                    else:
                        # Newest user message: keep as much as possible.
                        msg["content"] = self._safe_truncate(msg["content"], int(effective_limit * 2))
    return messages
def _compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
    """Compress assistant messages, sparing the most recent one.

    Same strategy as _compress_user_messages but filtered on
    role == 'assistant'. Messages are mutated in place.

    Args:
        messages: Conversation messages; mutated in place and returned.
        llm_model: Model name passed to token_counter.
        max_tokens: Total-token budget; defaults to 100k when None.
        token_threshold: Per-message token limit before compression.

    Returns:
        The same messages list, possibly with compressed content.
    """
    # Fix: resolve the budget once so `int(max_tokens * 2)` can no longer
    # raise TypeError when max_tokens is None.
    effective_limit = max_tokens or (100 * 1000)
    uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
    if uncompressed_total_token_count > effective_limit:
        assistant_count = 0  # How many assistant messages seen so far (newest first)
        for msg in reversed(messages):
            if msg.get('role') == 'assistant':
                assistant_count += 1
                msg_token_count = token_counter(messages=[msg])
                if msg_token_count > token_threshold:
                    if assistant_count > 1:  # Not the most recent assistant message
                        message_id = msg.get('message_id')
                        if message_id:
                            msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
                        else:
                            logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
                    else:
                        # Newest assistant message: keep as much as possible.
                        msg["content"] = self._safe_truncate(msg["content"], int(effective_limit * 2))
    return messages
def _compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = None, token_threshold: Optional[int] = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
    """Compress the messages until they fit the model's context budget.

    Runs tool-result, user, and assistant compression passes, then recurses
    with a halved token_threshold (up to max_iterations times) if the total
    is still over budget. The passes mutate messages in place, so `messages`
    and the returned list alias the same dicts.

    Args:
        messages: Conversation messages to compress.
        llm_model: Model name; used both for token counting and to pick a
            default context budget.
        max_tokens: Context budget. Fix: the original signature defaulted
            this to 41000 but then unconditionally overwrote it from the
            model name, making the parameter dead. It now defaults to None
            and the model-based budget is applied only when the caller did
            not supply a value — behavior-identical for all existing call
            sites (the external caller passes nothing; the recursive call
            passes the already-resolved budget).
        token_threshold: Per-message token limit; must be a power of 2 so
            repeated halving stays integral.
        max_iterations: Recursion guard.

    Returns:
        The (possibly) compressed messages list.
    """
    if max_iterations <= 0:
        logger.warning(f"_compress_messages: Max iterations reached, returning uncompressed messages")
        return messages
    if max_tokens is None:
        model_name = llm_model.lower()
        # Per-model context budgets, with headroom reserved for the response.
        if 'sonnet' in model_name:
            max_tokens = 200 * 1000 - 64000
        elif 'gpt' in model_name:
            max_tokens = 128 * 1000 - 28000
        elif 'gemini' in model_name:
            max_tokens = 1000 * 1000 - 300000
        elif 'deepseek' in model_name:
            max_tokens = 163 * 1000 - 32000
        else:
            max_tokens = 41 * 1000 - 10000
    uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
    result = messages
    result = self._compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
    result = self._compress_user_messages(result, llm_model, max_tokens, token_threshold)
    result = self._compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
    compressed_token_count = token_counter(model=llm_model, messages=result)
    # Log the compression ratio for debugging later.
    logger.info(f"_compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}")
    if (compressed_token_count > max_tokens):
        logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
        # Recurse with a tighter per-message threshold; `messages` was
        # mutated in place above, so this continues from the current state.
        result = self._compress_messages(messages, llm_model, max_tokens, int(token_threshold / 2), max_iterations - 1)
    return result
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
"""Add a tool to the ThreadManager."""
@@ -287,7 +416,6 @@ Here are the XML tools available with examples:
# 2. Check token count before proceeding
token_count = 0
try:
from litellm import token_counter
# Use the potentially modified working_system_prompt for token counting
token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
token_threshold = self.context_manager.token_threshold
@@ -344,25 +472,7 @@ Here are the XML tools available with examples:
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
if uncompressed_total_token_count > (llm_max_tokens or (100 * 1000)):
_i = 0 # Count the number of ToolResult messages
for msg in reversed(prepared_messages): # Start from the end and work backwards
if self._is_tool_result_message(msg): # Only compress ToolResult messages
_i += 1 # Count the number of ToolResult messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > 5000: # If the message is too long
if _i > 1: # If this is not the most recent ToolResult message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = msg["content"][:10000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message" # Truncate the message
else:
msg["content"] = msg["content"][:200000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it" # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise
compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}") # Log the token compression for debugging later
prepared_messages = self._compress_messages(prepared_messages, llm_model)
# 5. Make LLM API call
logger.debug("Making LLM API call")

View File

@@ -32,7 +32,7 @@ stripe>=12.0.1
dramatiq>=1.17.1
pika>=1.3.2
prometheus-client>=0.21.1
langfuse>=2.60.5
langfuse==2.60.5
httpx>=0.24.0
Pillow>=10.0.0
sentry-sdk[fastapi]>=2.29.1

View File

@@ -1247,7 +1247,7 @@ export const listSandboxFiles = async (
return data.files || [];
} catch (error) {
console.error('Failed to list sandbox files:', error);
handleApiError(error, { operation: 'list files', resource: `directory ${path}` });
// handleApiError(error, { operation: 'list files', resource: `directory ${path}` });
throw error;
}
};
@@ -1475,8 +1475,6 @@ export const checkApiHealth = async (): Promise<HealthCheckResponse> => {
return response.json();
} catch (error) {
console.error('API health check failed:', error);
handleApiError(error, { operation: 'check system health', resource: 'system status' });
throw error;
}
};
@@ -1783,7 +1781,6 @@ export const checkBillingStatus = async (): Promise<BillingStatusResponse> => {
return response.json();
} catch (error) {
console.error('Failed to check billing status:', error);
handleApiError(error, { operation: 'check billing status', resource: 'account status' });
throw error;
}
};

View File

@@ -1,136 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify enhanced response processor for agent builder tools.
"""
import asyncio
import json
from backend.agentpress.response_processor import ResponseProcessor
from backend.agentpress.tool_registry import ToolRegistry
from backend.agentpress.tool import ToolResult
class MockTool:
    """Minimal stand-in for a real tool; builds ToolResult objects for tests."""

    def fail_response(self, msg):
        # Failures pass the message straight through as the output.
        return ToolResult(success=False, output=msg)

    def success_response(self, data):
        # Successes pretty-print the payload, matching real tool behavior.
        return ToolResult(success=True, output=json.dumps(data, indent=2))
async def mock_add_message(thread_id, type, content, is_llm_message, metadata=None):
    """Async stub for the add-message callback: echo the inputs back as a stored row."""
    stored_row = dict(
        message_id="test-message-id",  # fixed id so assertions are deterministic
        thread_id=thread_id,
        type=type,
        content=content,
        is_llm_message=is_llm_message,
        metadata=metadata or {},
    )
    return stored_row
def test_update_agent_response():
    """Test update_agent tool response formatting."""
    # Payload mirroring what the real update_agent tool returns.
    agent_payload = {
        "message": "Agent updated successfully",
        "updated_fields": ["name", "description", "system_prompt"],
        "agent": {
            "agent_id": "test-agent-123",
            "name": "Research Assistant",
            "description": "An AI assistant specialized in research",
            "system_prompt": "You are a research assistant with expertise in gathering, analyzing, and synthesizing information from various sources.",
            "agentpress_tools": {
                "web_search": {"enabled": True, "description": "Search the web"},
                "sb_files": {"enabled": True, "description": "File operations"}
            },
            "configured_mcps": [
                {"name": "Exa Search", "qualifiedName": "exa", "enabledTools": ["search"]}
            ],
            "avatar": "🔬",
            "avatar_color": "#4F46E5"
        }
    }
    builder_tool = MockTool()
    update_result = builder_tool.success_response(agent_payload)

    # The tool call under test, shared by both processor configurations.
    tool_call = {
        "function_name": "update_agent",
        "xml_tag_name": "update_agent",
        "arguments": {"name": "Research Assistant"}
    }
    registry = ToolRegistry()

    # Case 1: agent-builder mode, targeting a specific agent.
    builder_processor = ResponseProcessor(
        tool_registry=registry,
        add_message_callback=mock_add_message,
        is_agent_builder=True,
        target_agent_id="test-agent-123"
    )
    builder_output = builder_processor._create_structured_tool_result(tool_call, update_result)
    print("=== Agent Builder Mode - Update Agent Tool Response ===")
    print(builder_output["summary"])
    print("\n" + "="*60 + "\n")

    # Case 2: normal mode, no agent-builder context.
    normal_processor = ResponseProcessor(
        tool_registry=registry,
        add_message_callback=mock_add_message,
        is_agent_builder=False
    )
    normal_output = normal_processor._create_structured_tool_result(tool_call, update_result)
    print("=== Normal Mode - Update Agent Tool Response ===")
    print(normal_output["summary"])
    print("\n" + "="*60 + "\n")
def test_get_current_agent_config_response():
    """Test get_current_agent_config tool response formatting."""
    # Payload mirroring what the real get_current_agent_config tool returns.
    config_payload = {
        "summary": "Agent 'Research Assistant' has 2 tools enabled and 1 MCP servers configured.",
        "configuration": {
            "agent_id": "test-agent-123",
            "name": "Research Assistant",
            "description": "An AI assistant specialized in research",
            "system_prompt": "You are a research assistant with expertise in gathering, analyzing, and synthesizing information from various sources. Your approach is thorough and methodical.",
            "agentpress_tools": {
                "web_search": {"enabled": True, "description": "Search the web"},
                "sb_files": {"enabled": False, "description": "File operations"}
            },
            "configured_mcps": [],
            "avatar": "🔬",
            "avatar_color": "#4F46E5"
        }
    }
    config_result = MockTool().success_response(config_payload)

    # Only the agent-builder path is exercised for this tool.
    builder_processor = ResponseProcessor(
        tool_registry=ToolRegistry(),
        add_message_callback=mock_add_message,
        is_agent_builder=True,
        target_agent_id="test-agent-123"
    )
    tool_call = {
        "function_name": "get_current_agent_config",
        "xml_tag_name": "get_current_agent_config",
        "arguments": {}
    }
    builder_output = builder_processor._create_structured_tool_result(tool_call, config_result)
    print("=== Agent Builder Mode - Get Current Agent Config Response ===")
    print(builder_output["summary"])
    print("\n" + "="*60 + "\n")
def main():
    """Run every response-formatting check in sequence."""
    print("Testing Enhanced Response Processor for Agent Builder Tools\n")
    test_update_agent_response()
    test_get_current_agent_config_response()
    print("✅ All tests completed!")


if __name__ == "__main__":
    main()