mirror of https://github.com/kortix-ai/suna.git
Merge pull request #562 from tnfssc/feat/context-compression
This commit is contained in:
commit
16d4237e49
|
@ -18,6 +18,7 @@ from agent.tools.sb_shell_tool import SandboxShellTool
|
||||||
from agent.tools.sb_files_tool import SandboxFilesTool
|
from agent.tools.sb_files_tool import SandboxFilesTool
|
||||||
from agent.tools.sb_browser_tool import SandboxBrowserTool
|
from agent.tools.sb_browser_tool import SandboxBrowserTool
|
||||||
from agent.tools.data_providers_tool import DataProvidersTool
|
from agent.tools.data_providers_tool import DataProvidersTool
|
||||||
|
from agent.tools.expand_msg_tool import ExpandMessageTool
|
||||||
from agent.prompt import get_system_prompt
|
from agent.prompt import get_system_prompt
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from utils.auth_utils import get_account_id_from_thread
|
from utils.auth_utils import get_account_id_from_thread
|
||||||
|
@ -75,6 +76,7 @@ async def run_agent(
|
||||||
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
|
thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
|
||||||
|
thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
|
||||||
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||||
# Add data providers tool if RapidAPI key is available
|
# Add data providers tool if RapidAPI key is available
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
||||||
|
from agentpress.thread_manager import ThreadManager
|
||||||
|
import json
|
||||||
|
|
||||||
|
class ExpandMessageTool(Tool):
    """Tool that retrieves a previously truncated message and returns its full content."""

    def __init__(self, thread_id: str, thread_manager: ThreadManager):
        super().__init__()
        # Keep the owning thread id and the manager whose DB client is used
        # to look the message up at call time.
        self.thread_manager = thread_manager
        self.thread_id = thread_id

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "expand_message",
            "description": "Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation.",
            "parameters": {
                "type": "object",
                "properties": {
                    "message_id": {
                        "type": "string",
                        "description": "The ID of the message to expand. Must be a UUID."
                    }
                },
                "required": ["message_id"]
            }
        }
    })
    @xml_schema(
        tag_name="expand-message",
        mappings=[
            {"param_name": "message_id", "node_type": "attribute", "path": "."}
        ],
        example='''
        Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation. The message_id must be a valid UUID.

        <!-- Use expand-message when you need to expand a message that was truncated in the previous conversation -->
        <!-- Use this tool when you need to create reports or analyze the data that resides in a truncated message -->
        <!-- Examples of when to use expand-message: -->
        <!-- The message was truncated in the earlier conversation -->

        <expand-message message_id="ecde3a4c-c7dc-4776-ae5c-8209517c5576"></expand-message>
        '''
    )
    async def expand_message(self, message_id: str) -> ToolResult:
        """Expand a message from the previous conversation with the user.

        Args:
            message_id: The ID of the message to expand

        Returns:
            ToolResult carrying the expanded content on success, or a failure
            result when the message is missing or the lookup raises.
        """
        try:
            db = await self.thread_manager.db.client
            # Scope the lookup to this thread so a caller cannot expand
            # messages belonging to another conversation.
            lookup = await db.table('messages').select('*').eq('message_id', message_id).eq('thread_id', self.thread_id).execute()

            if not lookup.data:
                return self.fail_response(f"Message with ID {message_id} not found in thread {self.thread_id}")

            expanded = self._unwrap_content(lookup.data[0]['content'])
            return self.success_response({"status": "Message expanded successfully.", "message": expanded})
        except Exception as e:
            return self.fail_response(f"Error expanding message: {str(e)}")

    @staticmethod
    def _unwrap_content(raw):
        """Return the inner 'content' field of a dict or JSON-encoded dict payload.

        Falls back to the payload unchanged when it is neither a dict with a
        'content' key nor a JSON string decoding to one.
        """
        if isinstance(raw, dict) and 'content' in raw:
            return raw['content']
        if isinstance(raw, str):
            try:
                decoded = json.loads(raw)
            except json.JSONDecodeError:
                return raw
            if isinstance(decoded, dict) and 'content' in decoded:
                return decoded['content']
        return raw
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import asyncio

    async def test_expand_message_tool():
        """Ad-hoc smoke test for ExpandMessageTool.expand_message.

        NOTE(review): the original called ExpandMessageTool() with no
        arguments, which raises TypeError immediately because __init__
        requires thread_id and thread_manager. Placeholder arguments are
        passed instead so the script reaches expand_message; without a real
        ThreadManager the call returns a fail_response ToolResult (the
        method catches the resulting AttributeError) rather than crashing.
        Wire in a real ThreadManager to exercise the success path.
        """
        expand_message_tool = ExpandMessageTool(
            thread_id="00000000-0000-0000-0000-000000000000",
            thread_manager=None,
        )

        # Test expand message
        expand_message_result = await expand_message_tool.expand_message(
            message_id="004ab969-ef9a-4656-8aba-e392345227cd"
        )
        print("Expand message result:", expand_message_result)

    asyncio.run(test_expand_message_tool())
|
|
@ -119,7 +119,8 @@ class ThreadManager:
|
||||||
client = await self.db.client
|
client = await self.db.client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()
|
# result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()
|
||||||
|
result = await client.table('messages').select('*').eq('thread_id', thread_id).eq('is_llm_message', True).order('created_at').execute()
|
||||||
|
|
||||||
# Parse the returned data which might be stringified JSON
|
# Parse the returned data which might be stringified JSON
|
||||||
if not result.data:
|
if not result.data:
|
||||||
|
@ -128,23 +129,17 @@ class ThreadManager:
|
||||||
# Return properly parsed JSON objects
|
# Return properly parsed JSON objects
|
||||||
messages = []
|
messages = []
|
||||||
for item in result.data:
|
for item in result.data:
|
||||||
if isinstance(item, str):
|
if isinstance(item['content'], str):
|
||||||
try:
|
try:
|
||||||
parsed_item = json.loads(item)
|
parsed_item = json.loads(item['content'])
|
||||||
|
parsed_item['message_id'] = item['message_id']
|
||||||
messages.append(parsed_item)
|
messages.append(parsed_item)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"Failed to parse message: {item}")
|
logger.error(f"Failed to parse message: {item['content']}")
|
||||||
else:
|
else:
|
||||||
messages.append(item)
|
content = item['content']
|
||||||
|
content['message_id'] = item['message_id']
|
||||||
# Ensure tool_calls have properly formatted function arguments
|
messages.append(content)
|
||||||
for message in messages:
|
|
||||||
if message.get('tool_calls'):
|
|
||||||
for tool_call in message['tool_calls']:
|
|
||||||
if isinstance(tool_call, dict) and 'function' in tool_call:
|
|
||||||
# Ensure function.arguments is a string
|
|
||||||
if 'arguments' in tool_call['function'] and not isinstance(tool_call['function']['arguments'], str):
|
|
||||||
tool_call['function']['arguments'] = json.dumps(tool_call['function']['arguments'])
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
@ -327,6 +322,26 @@ Here are the XML tools available with examples:
|
||||||
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
|
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
|
||||||
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
|
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
|
||||||
|
|
||||||
|
|
||||||
|
uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
|
||||||
|
|
||||||
|
if uncompressed_total_token_count > llm_max_tokens:
|
||||||
|
_i = 0 # Count the number of ToolResult messages
|
||||||
|
for msg in reversed(prepared_messages): # Start from the end and work backwards
|
||||||
|
if "content" in msg and msg['content'] and "ToolResult" in msg['content']: # Only compress ToolResult messages
|
||||||
|
_i += 1 # Count the number of ToolResult messages
|
||||||
|
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
|
||||||
|
if msg_token_count > 5000: # If the message is too long
|
||||||
|
if _i > 1: # If this is not the most recent ToolResult message
|
||||||
|
message_id = msg.get('message_id') # Get the message_id
|
||||||
|
if message_id:
|
||||||
|
msg["content"] = msg["content"][:10000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message" # Truncate the message
|
||||||
|
else:
|
||||||
|
msg["content"] = msg["content"][:300000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it" # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise
|
||||||
|
|
||||||
|
compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
|
||||||
|
logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}") # Log the token compression for debugging later
|
||||||
|
|
||||||
# 5. Make LLM API call
|
# 5. Make LLM API call
|
||||||
logger.debug("Making LLM API call")
|
logger.debug("Making LLM API call")
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Reference in New Issue