mirror of https://github.com/kortix-ai/suna.git
Merge pull request #562 from tnfssc/feat/context-compression
commit 16d4237e49
@@ -18,6 +18,7 @@ from agent.tools.sb_shell_tool import SandboxShellTool
 from agent.tools.sb_files_tool import SandboxFilesTool
 from agent.tools.sb_browser_tool import SandboxBrowserTool
 from agent.tools.data_providers_tool import DataProvidersTool
+from agent.tools.expand_msg_tool import ExpandMessageTool
 from agent.prompt import get_system_prompt
 from utils.logger import logger
 from utils.auth_utils import get_account_id_from_thread
@@ -75,6 +76,7 @@ async def run_agent(
     thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
     thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
     thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
+    thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
     thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
     thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
     # Add data providers tool if RapidAPI key is available
@@ -0,0 +1,91 @@
+from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
+from agentpress.thread_manager import ThreadManager
+import json
+
+class ExpandMessageTool(Tool):
+    """Tool for expanding a previous message to the user."""
+
+    def __init__(self, thread_id: str, thread_manager: ThreadManager):
+        super().__init__()
+        self.thread_manager = thread_manager
+        self.thread_id = thread_id
+
+    @openapi_schema({
+        "type": "function",
+        "function": {
+            "name": "expand_message",
+            "description": "Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "message_id": {
+                        "type": "string",
+                        "description": "The ID of the message to expand. Must be a UUID."
+                    }
+                },
+                "required": ["message_id"]
+            }
+        }
+    })
+    @xml_schema(
+        tag_name="expand-message",
+        mappings=[
+            {"param_name": "message_id", "node_type": "attribute", "path": "."}
+        ],
+        example='''
+Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation. The message_id must be a valid UUID.
+
+<!-- Use expand-message when you need to expand a message that was truncated in the previous conversation -->
+<!-- Use this tool when you need to create reports or analyze the data that resides in a truncated message -->
+<!-- Examples of when to use expand-message: -->
+<!-- The message was truncated in the earlier conversation -->
+
+<expand-message message_id="ecde3a4c-c7dc-4776-ae5c-8209517c5576"></expand-message>
+        '''
+    )
+    async def expand_message(self, message_id: str) -> ToolResult:
+        """Expand a message from the previous conversation with the user.
+
+        Args:
+            message_id: The ID of the message to expand
+
+        Returns:
+            ToolResult indicating the message was successfully expanded
+        """
+        try:
+            client = await self.thread_manager.db.client
+            message = await client.table('messages').select('*').eq('message_id', message_id).eq('thread_id', self.thread_id).execute()
+
+            if not message.data or len(message.data) == 0:
+                return self.fail_response(f"Message with ID {message_id} not found in thread {self.thread_id}")
+
+            message_data = message.data[0]
+            message_content = message_data['content']
+            final_content = message_content
+            if isinstance(message_content, dict) and 'content' in message_content:
+                final_content = message_content['content']
+            elif isinstance(message_content, str):
+                try:
+                    parsed_content = json.loads(message_content)
+                    if isinstance(parsed_content, dict) and 'content' in parsed_content:
+                        final_content = parsed_content['content']
+                except json.JSONDecodeError:
+                    pass
+
+            return self.success_response({"status": "Message expanded successfully.", "message": final_content})
+        except Exception as e:
+            return self.fail_response(f"Error expanding message: {str(e)}")
+
+if __name__ == "__main__":
+    import asyncio
+
+    async def test_expand_message_tool():
+        expand_message_tool = ExpandMessageTool(thread_id="replace-with-a-real-thread-uuid", thread_manager=ThreadManager())
+
+        # Test expand message (thread_id above is a placeholder; point it at a real thread)
+        expand_message_result = await expand_message_tool.expand_message(
+            message_id="004ab969-ef9a-4656-8aba-e392345227cd"
+        )
+        print("Expand message result:", expand_message_result)
+
+    asyncio.run(test_expand_message_tool())
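A quick orientation note (not part of the diff): the xml_schema decorator above registers the <expand-message> tag for this method, so when an earlier ToolResult has been truncated the model can recover it by emitting the tag with the preserved message_id. Roughly, reusing the UUID from the schema example:

# The model emits the XML call declared in xml_schema above:
#   <expand-message message_id="ecde3a4c-c7dc-4776-ae5c-8209517c5576"></expand-message>
# which resolves to:
#   await expand_message(message_id="ecde3a4c-c7dc-4776-ae5c-8209517c5576")
# and, on success, returns a ToolResult wrapping:
#   {"status": "Message expanded successfully.", "message": <full stored message content>}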
@@ -119,7 +119,8 @@ class ThreadManager:
         client = await self.db.client

         try:
-            result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()
+            # result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()
+            result = await client.table('messages').select('*').eq('thread_id', thread_id).eq('is_llm_message', True).order('created_at').execute()

             # Parse the returned data which might be stringified JSON
             if not result.data:
@@ -128,23 +129,17 @@ class ThreadManager:
             # Return properly parsed JSON objects
             messages = []
             for item in result.data:
-                if isinstance(item, str):
+                if isinstance(item['content'], str):
                     try:
-                        parsed_item = json.loads(item)
+                        parsed_item = json.loads(item['content'])
+                        parsed_item['message_id'] = item['message_id']
                         messages.append(parsed_item)
                     except json.JSONDecodeError:
-                        logger.error(f"Failed to parse message: {item}")
+                        logger.error(f"Failed to parse message: {item['content']}")
                 else:
-                    messages.append(item)
-
-            # Ensure tool_calls have properly formatted function arguments
-            for message in messages:
-                if message.get('tool_calls'):
-                    for tool_call in message['tool_calls']:
-                        if isinstance(tool_call, dict) and 'function' in tool_call:
-                            # Ensure function.arguments is a string
-                            if 'arguments' in tool_call['function'] and not isinstance(tool_call['function']['arguments'], str):
-                                tool_call['function']['arguments'] = json.dumps(tool_call['function']['arguments'])
+                    content = item['content']
+                    content['message_id'] = item['message_id']
+                    messages.append(content)

             return messages

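For reference (not part of the diff): with the direct messages-table query in the previous hunk and this reworked parsing, every entry handed to the LLM now carries its message_id, which is what the compression pass further down uses to point the model at a truncated ToolResult. A rough illustration, assuming the content column stores stringified JSON (column names come from the queries above, values are placeholders):

row = {
    "message_id": "ecde3a4c-c7dc-4776-ae5c-8209517c5576",
    "thread_id": "replace-with-a-thread-uuid",
    "is_llm_message": True,
    "content": '{"role": "tool", "content": "ToolResult for the browser action, several thousand tokens long"}',
}
# After json.loads(row['content']) plus the message_id tag, get_llm_messages yields:
# {"role": "tool", "content": "ToolResult for the browser action, several thousand tokens long",
#  "message_id": "ecde3a4c-c7dc-4776-ae5c-8209517c5576"}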
@@ -327,6 +322,26 @@ Here are the XML tools available with examples:
         openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
         logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")

+
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
+
+        if uncompressed_total_token_count > llm_max_tokens:
+            _i = 0 # Count the number of ToolResult messages
+            for msg in reversed(prepared_messages): # Start from the end and work backwards
+                if "content" in msg and msg['content'] and "ToolResult" in msg['content']: # Only compress ToolResult messages
+                    _i += 1 # Count the number of ToolResult messages
+                    msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
+                    if msg_token_count > 5000: # If the message is too long
+                        if _i > 1: # If this is not the most recent ToolResult message
+                            message_id = msg.get('message_id') # Get the message_id
+                            if message_id:
+                                msg["content"] = msg["content"][:10000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message" # Truncate the message
+                        else:
+                            msg["content"] = msg["content"][:300000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it" # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise
+
+        compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
+        logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}") # Log the token compression for debugging later
+
         # 5. Make LLM API call
         logger.debug("Making LLM API call")
         try:
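For readers, the compression rule above can be distilled into a standalone sketch (not code from the commit; the 5,000-token trigger and the 10,000/300,000-character caps are copied from the diff, and only messages whose content contains the substring "ToolResult" are considered at all):

def compress_tool_result_content(content: str, message_id: str | None, is_most_recent: bool) -> str:
    """Sketch of the per-message rule; only reached when the message exceeds ~5,000 tokens."""
    if is_most_recent:
        # The newest ToolResult stays (almost) intact; only extreme cases are capped at 300k chars.
        return content[:300000] + "\n\nThis message is too long, repeat relevant information in your response to remember it"
    if message_id:
        # Older ToolResults shrink to 10k chars plus a pointer the model can follow via expand-message.
        return (content[:10000] + "... (truncated)"
                + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message")
    # Without a message_id there is nothing to point at, so the content is left as-is.
    return content

Keeping the most recent ToolResult nearly intact (the _i == 1 branch) lets the agent keep working with fresh tool output while older results collapse to expandable stubs, and the before/after totals are logged as token_compression for later inspection.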