mirror of https://github.com/kortix-ai/suna.git
Merge branch 'PRODUCTION' of https://github.com/kortix-ai/suna into toasts-fix
commit ed66a0c0f8
@@ -198,3 +198,5 @@ rabbitmq_data
.setup_progress
.setup_env.json
backend/.test_token_compression.py

@@ -25,6 +25,7 @@ from utils.logger import logger
 from langfuse.client import StatefulGenerationClient, StatefulTraceClient
 from services.langfuse import langfuse
 import datetime
+from litellm import token_counter

 # Type alias for tool choice
 ToolChoice = Literal["auto", "required", "none"]

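The new module-level `from litellm import token_counter` replaces a local import further down (see the hunk at -287 below). For reference, litellm's `token_counter` totals the tokens of a chat-message list under a given model's tokenizer; a minimal sketch (the model id is illustrative):

from litellm import token_counter

# Two-turn conversation; token_counter applies the model's tokenizer
# to every message and returns the total as an int.
messages = [
    {"role": "user", "content": "Summarize the build logs."},
    {"role": "assistant", "content": "The build failed at step 3."},
]
total = token_counter(model="gpt-4o", messages=messages)
print(f"conversation tokens: {total}")
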
@@ -74,6 +75,134 @@ class ThreadManager:
             except (json.JSONDecodeError, TypeError):
                 pass
         return False
+
+    def _compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
+        """Compress the message content."""
+        # print("max_length", max_length)
+        if isinstance(msg_content, str):
+            if len(msg_content) > max_length:
+                return msg_content[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
+            else:
+                return msg_content
+        elif isinstance(msg_content, dict):
+            if len(json.dumps(msg_content)) > max_length:
+                return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"
+            else:
+                return msg_content
+
+    def _safe_truncate(self, msg_content: Union[str, dict], max_length: int = 200000) -> Union[str, dict]:
+        """Truncate the message content safely."""
+        if isinstance(msg_content, str):
+            if len(msg_content) > max_length:
+                return msg_content[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
+            else:
+                return msg_content
+        elif isinstance(msg_content, dict):
+            if len(json.dumps(msg_content)) > max_length:
+                return json.dumps(msg_content)[:max_length] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
+            else:
+                return msg_content
+
+    def _compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
+        """Compress the tool result messages except the most recent one."""
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+
+        if uncompressed_total_token_count > (max_tokens or (64 * 1000)):
+            _i = 0  # Count the number of ToolResult messages
+            for msg in reversed(messages):  # Start from the end and work backwards
+                if self._is_tool_result_message(msg):  # Only compress ToolResult messages
+                    _i += 1  # Count the number of ToolResult messages
+                    msg_token_count = token_counter(messages=[msg])  # Count the number of tokens in the message
+                    if msg_token_count > token_threshold:  # If the message is too long
+                        if _i > 1:  # If this is not the most recent ToolResult message
+                            message_id = msg.get('message_id')  # Get the message_id
+                            if message_id:
+                                msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
+                            else:
+                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
+                        else:
+                            msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
+        return messages
+
+    def _compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
+        """Compress the user messages except the most recent one."""
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+
+        if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
+            _i = 0  # Count the number of User messages
+            for msg in reversed(messages):  # Start from the end and work backwards
+                if msg.get('role') == 'user':  # Only compress User messages
+                    _i += 1  # Count the number of User messages
+                    msg_token_count = token_counter(messages=[msg])  # Count the number of tokens in the message
+                    if msg_token_count > token_threshold:  # If the message is too long
+                        if _i > 1:  # If this is not the most recent User message
+                            message_id = msg.get('message_id')  # Get the message_id
+                            if message_id:
+                                msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
+                            else:
+                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
+                        else:
+                            msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
+        return messages
+
+    def _compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
+        """Compress the assistant messages except the most recent one."""
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+        if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
+            _i = 0  # Count the number of Assistant messages
+            for msg in reversed(messages):  # Start from the end and work backwards
+                if msg.get('role') == 'assistant':  # Only compress Assistant messages
+                    _i += 1  # Count the number of Assistant messages
+                    msg_token_count = token_counter(messages=[msg])  # Count the number of tokens in the message
+                    if msg_token_count > token_threshold:  # If the message is too long
+                        if _i > 1:  # If this is not the most recent Assistant message
+                            message_id = msg.get('message_id')  # Get the message_id
+                            if message_id:
+                                msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
+                            else:
+                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
+                        else:
+                            msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
+
+        return messages
+
+    def _compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: Optional[int] = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
+        """Compress the messages.
+        token_threshold: must be a power of 2
+        """
+
+        if 'sonnet' in llm_model.lower():
+            max_tokens = 200 * 1000 - 64000
+        elif 'gpt' in llm_model.lower():
+            max_tokens = 128 * 1000 - 28000
+        elif 'gemini' in llm_model.lower():
+            max_tokens = 1000 * 1000 - 300000
+        elif 'deepseek' in llm_model.lower():
+            max_tokens = 163 * 1000 - 32000
+        else:
+            max_tokens = 41 * 1000 - 10000
+
+        if max_iterations <= 0:
+            logger.warning(f"_compress_messages: Max iterations reached, returning uncompressed messages")
+            return messages
+
+        result = messages
+
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+
+        result = self._compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
+        result = self._compress_user_messages(result, llm_model, max_tokens, token_threshold)
+        result = self._compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
+
+        compressed_token_count = token_counter(model=llm_model, messages=result)
+
+        logger.info(f"_compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}")  # Log the token compression for debugging later
+
+        if (compressed_token_count > max_tokens):
+            logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
+            result = self._compress_messages(messages, llm_model, max_tokens, int(token_threshold / 2), max_iterations - 1)
+
+        return result

     def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
         """Add a tool to the ThreadManager."""

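Taken together, these helpers implement a two-tier policy: any oversized older message is compressed to a stub that tells the model how to recover the original via the expand-message tool, while the most recent message of each kind is only safe-truncated. A minimal standalone sketch of the stub pattern (function name and sizes are illustrative, not the committed API):

# Toy version of the truncate-and-point pattern used by _compress_message.
def compress(content: str, message_id: str, max_length: int = 3000) -> str:
    if len(content) <= max_length:
        return content  # Short enough: pass through unchanged.
    return (
        content[:max_length]
        + "... (truncated)"
        + "\n\nThis message is too long, use the expand-message tool "
        + f'with message_id "{message_id}" to see the full message'
    )

print(compress("x" * 10_000, "msg-123")[-60:])  # Tail carries the expand hint.
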
@@ -287,7 +416,6 @@ Here are the XML tools available with examples:
         # 2. Check token count before proceeding
         token_count = 0
         try:
-            from litellm import token_counter
             # Use the potentially modified working_system_prompt for token counting
             token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
             token_threshold = self.context_manager.token_threshold

@@ -344,25 +472,7 @@ Here are the XML tools available with examples:
         openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
         logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")

-        uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
-
-        if uncompressed_total_token_count > (llm_max_tokens or (100 * 1000)):
-            _i = 0  # Count the number of ToolResult messages
-            for msg in reversed(prepared_messages):  # Start from the end and work backwards
-                if self._is_tool_result_message(msg):  # Only compress ToolResult messages
-                    _i += 1  # Count the number of ToolResult messages
-                    msg_token_count = token_counter(messages=[msg])  # Count the number of tokens in the message
-                    if msg_token_count > 5000:  # If the message is too long
-                        if _i > 1:  # If this is not the most recent ToolResult message
-                            message_id = msg.get('message_id')  # Get the message_id
-                            if message_id:
-                                msg["content"] = msg["content"][:10000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message"  # Truncate the message
-                        else:
-                            msg["content"] = msg["content"][:200000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it"  # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise
-
-        compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
-        logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}")  # Log the token compression for debugging later
+        prepared_messages = self._compress_messages(prepared_messages, llm_model)

         # 5. Make LLM API call
         logger.debug("Making LLM API call")

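The deleted inline block is superseded by the single `_compress_messages` call added above, which picks a per-model context budget and retries with a halved per-message threshold until the conversation fits or max_iterations runs out — hence the docstring's note that token_threshold should be a power of 2. A toy loop showing the effective thresholds, not the committed code:

# Thresholds tried by successive recursive passes of _compress_messages
# with the defaults token_threshold=4096, max_iterations=5.
token_threshold, max_iterations = 4096, 5
for attempt in range(1, max_iterations + 1):
    print(f"pass {attempt}: compress messages larger than ~{token_threshold} tokens")
    token_threshold //= 2  # int(token_threshold / 2) in the committed code
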
@@ -32,7 +32,7 @@ stripe>=12.0.1
 dramatiq>=1.17.1
 pika>=1.3.2
 prometheus-client>=0.21.1
-langfuse>=2.60.5
+langfuse==2.60.5
 httpx>=0.24.0
 Pillow>=10.0.0
 sentry-sdk[fastapi]>=2.29.1

@@ -1247,7 +1247,7 @@ export const listSandboxFiles = async (
     return data.files || [];
   } catch (error) {
     console.error('Failed to list sandbox files:', error);
-    handleApiError(error, { operation: 'list files', resource: `directory ${path}` });
+    // handleApiError(error, { operation: 'list files', resource: `directory ${path}` });
     throw error;
   }
 };

@@ -1475,8 +1475,6 @@ export const checkApiHealth = async (): Promise<HealthCheckResponse> => {

     return response.json();
   } catch (error) {
     console.error('API health check failed:', error);
-    handleApiError(error, { operation: 'check system health', resource: 'system status' });
     throw error;
   }
 };

@@ -1783,7 +1781,6 @@ export const checkBillingStatus = async (): Promise<BillingStatusResponse> => {
     return response.json();
   } catch (error) {
     console.error('Failed to check billing status:', error);
-    handleApiError(error, { operation: 'check billing status', resource: 'account status' });
     throw error;
   }
 };

@@ -1,136 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify enhanced response processor for agent builder tools.
-"""
-
-import asyncio
-import json
-from backend.agentpress.response_processor import ResponseProcessor
-from backend.agentpress.tool_registry import ToolRegistry
-from backend.agentpress.tool import ToolResult
-
-class MockTool:
-    """Mock tool for testing."""
-
-    def success_response(self, data):
-        return ToolResult(success=True, output=json.dumps(data, indent=2))
-
-    def fail_response(self, msg):
-        return ToolResult(success=False, output=msg)
-
-async def mock_add_message(thread_id, type, content, is_llm_message, metadata=None):
-    """Mock add message callback."""
-    return {
-        "message_id": "test-message-id",
-        "thread_id": thread_id,
-        "type": type,
-        "content": content,
-        "is_llm_message": is_llm_message,
-        "metadata": metadata or {}
-    }
-
-def test_update_agent_response():
-    """Test update_agent tool response formatting."""
-
-    # Create mock tool result for update_agent
-    mock_tool = MockTool()
-    update_result = mock_tool.success_response({
-        "message": "Agent updated successfully",
-        "updated_fields": ["name", "description", "system_prompt"],
-        "agent": {
-            "agent_id": "test-agent-123",
-            "name": "Research Assistant",
-            "description": "An AI assistant specialized in research",
-            "system_prompt": "You are a research assistant with expertise in gathering, analyzing, and synthesizing information from various sources.",
-            "agentpress_tools": {
-                "web_search": {"enabled": True, "description": "Search the web"},
-                "sb_files": {"enabled": True, "description": "File operations"}
-            },
-            "configured_mcps": [
-                {"name": "Exa Search", "qualifiedName": "exa", "enabledTools": ["search"]}
-            ],
-            "avatar": "🔬",
-            "avatar_color": "#4F46E5"
-        }
-    })
-
-    # Test with agent builder mode
-    tool_registry = ToolRegistry()
-    processor = ResponseProcessor(
-        tool_registry=tool_registry,
-        add_message_callback=mock_add_message,
-        is_agent_builder=True,
-        target_agent_id="test-agent-123"
-    )
-
-    tool_call = {
-        "function_name": "update_agent",
-        "xml_tag_name": "update_agent",
-        "arguments": {"name": "Research Assistant"}
-    }
-
-    structured_result = processor._create_structured_tool_result(tool_call, update_result)
-
-    print("=== Agent Builder Mode - Update Agent Tool Response ===")
-    print(structured_result["summary"])
-    print("\n" + "="*60 + "\n")
-
-    # Test without agent builder mode
-    processor_normal = ResponseProcessor(
-        tool_registry=tool_registry,
-        add_message_callback=mock_add_message,
-        is_agent_builder=False
-    )
-
-    structured_result_normal = processor_normal._create_structured_tool_result(tool_call, update_result)
-
-    print("=== Normal Mode - Update Agent Tool Response ===")
-    print(structured_result_normal["summary"])
-    print("\n" + "="*60 + "\n")
-
-def test_get_current_agent_config_response():
-    """Test get_current_agent_config tool response formatting."""
-
-    mock_tool = MockTool()
-    config_result = mock_tool.success_response({
-        "summary": "Agent 'Research Assistant' has 2 tools enabled and 1 MCP servers configured.",
-        "configuration": {
-            "agent_id": "test-agent-123",
-            "name": "Research Assistant",
-            "description": "An AI assistant specialized in research",
-            "system_prompt": "You are a research assistant with expertise in gathering, analyzing, and synthesizing information from various sources. Your approach is thorough and methodical.",
-            "agentpress_tools": {
-                "web_search": {"enabled": True, "description": "Search the web"},
-                "sb_files": {"enabled": False, "description": "File operations"}
-            },
-            "configured_mcps": [],
-            "avatar": "🔬",
-            "avatar_color": "#4F46E5"
-        }
-    })
-
-    tool_registry = ToolRegistry()
-    processor = ResponseProcessor(
-        tool_registry=tool_registry,
-        add_message_callback=mock_add_message,
-        is_agent_builder=True,
-        target_agent_id="test-agent-123"
-    )
-
-    tool_call = {
-        "function_name": "get_current_agent_config",
-        "xml_tag_name": "get_current_agent_config",
-        "arguments": {}
-    }
-
-    structured_result = processor._create_structured_tool_result(tool_call, config_result)
-
-    print("=== Agent Builder Mode - Get Current Agent Config Response ===")
-    print(structured_result["summary"])
-    print("\n" + "="*60 + "\n")
-
-if __name__ == "__main__":
-    print("Testing Enhanced Response Processor for Agent Builder Tools\n")
-    test_update_agent_response()
-    test_get_current_agent_config_response()
-    print("✅ All tests completed!")