Merge pull request #562 from tnfssc/feat/context-compression

This commit is contained in:
Sharath 2025-05-29 19:34:44 +05:30 committed by GitHub
commit 16d4237e49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 122 additions and 14 deletions

View File

@ -18,6 +18,7 @@ from agent.tools.sb_shell_tool import SandboxShellTool
from agent.tools.sb_files_tool import SandboxFilesTool
from agent.tools.sb_browser_tool import SandboxBrowserTool
from agent.tools.data_providers_tool import DataProvidersTool
from agent.tools.expand_msg_tool import ExpandMessageTool
from agent.prompt import get_system_prompt
from utils.logger import logger
from utils.auth_utils import get_account_id_from_thread
@ -75,6 +76,7 @@ async def run_agent(
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
# Add data providers tool if RapidAPI key is available

View File

@ -0,0 +1,91 @@
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.thread_manager import ThreadManager
import json
class ExpandMessageTool(Tool):
    """Tool that retrieves the full, untruncated content of a stored message."""

    def __init__(self, thread_id: str, thread_manager: ThreadManager):
        super().__init__()
        # Thread to scope lookups to, and the manager that owns the DB client.
        self.thread_id = thread_id
        self.thread_manager = thread_manager

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "expand_message",
            "description": "Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation.",
            "parameters": {
                "type": "object",
                "properties": {
                    "message_id": {
                        "type": "string",
                        "description": "The ID of the message to expand. Must be a UUID."
                    }
                },
                "required": ["message_id"]
            }
        }
    })
    @xml_schema(
        tag_name="expand-message",
        mappings=[
            {"param_name": "message_id", "node_type": "attribute", "path": "."}
        ],
        example='''
Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation. The message_id must be a valid UUID.
<!-- Use expand-message when you need to expand a message that was truncated in the previous conversation -->
<!-- Use this tool when you need to create reports or analyze the data that resides in a truncated message -->
<!-- Examples of when to use expand-message: -->
<!-- The message was truncated in the earlier conversation -->
<expand-message message_id="ecde3a4c-c7dc-4776-ae5c-8209517c5576"></expand-message>
        '''
    )
    async def expand_message(self, message_id: str) -> ToolResult:
        """Look up a message by ID within this thread and return its full content.

        Args:
            message_id: The ID of the message to expand

        Returns:
            ToolResult indicating the message was successfully expanded
        """
        try:
            client = await self.thread_manager.db.client
            row = await client.table('messages').select('*').eq('message_id', message_id).eq('thread_id', self.thread_id).execute()
            if not row.data:
                return self.fail_response(f"Message with ID {message_id} not found in thread {self.thread_id}")
            expanded = self._unwrap_content(row.data[0]['content'])
            return self.success_response({"status": "Message expanded successfully.", "message": expanded})
        except Exception as e:
            return self.fail_response(f"Error expanding message: {str(e)}")

    def _unwrap_content(self, content):
        """Return the inner 'content' field when the stored value wraps it.

        Handles both a dict wrapper and a JSON-encoded string wrapper;
        anything else is returned unchanged.
        """
        if isinstance(content, dict) and 'content' in content:
            return content['content']
        if isinstance(content, str):
            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                return content
            if isinstance(parsed, dict) and 'content' in parsed:
                return parsed['content']
        return content
if __name__ == "__main__":
    import asyncio

    async def test_expand_message_tool():
        # BUG FIX: ExpandMessageTool.__init__ requires thread_id and
        # thread_manager, but the original smoke test called
        # ExpandMessageTool() with no arguments and crashed with a
        # TypeError before exercising anything.
        # A real ThreadManager (with DB access) is needed to fetch an
        # actual message; passing thread_manager=None means the DB call
        # raises inside expand_message, which exercises the tool's
        # error path (fail_response) rather than a successful lookup.
        expand_message_tool = ExpandMessageTool(
            thread_id="00000000-0000-0000-0000-000000000000",
            thread_manager=None,
        )

        # Test expand message
        expand_message_result = await expand_message_tool.expand_message(
            message_id="004ab969-ef9a-4656-8aba-e392345227cd"
        )
        print("Expand message result:", expand_message_result)

    asyncio.run(test_expand_message_tool())

View File

@ -119,7 +119,8 @@ class ThreadManager:
client = await self.db.client
try:
result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()
# result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()
result = await client.table('messages').select('*').eq('thread_id', thread_id).eq('is_llm_message', True).order('created_at').execute()
# Parse the returned data which might be stringified JSON
if not result.data:
@ -128,23 +129,17 @@ class ThreadManager:
# Return properly parsed JSON objects
messages = []
for item in result.data:
if isinstance(item, str):
if isinstance(item['content'], str):
try:
parsed_item = json.loads(item)
parsed_item = json.loads(item['content'])
parsed_item['message_id'] = item['message_id']
messages.append(parsed_item)
except json.JSONDecodeError:
logger.error(f"Failed to parse message: {item}")
logger.error(f"Failed to parse message: {item['content']}")
else:
messages.append(item)
# Ensure tool_calls have properly formatted function arguments
for message in messages:
if message.get('tool_calls'):
for tool_call in message['tool_calls']:
if isinstance(tool_call, dict) and 'function' in tool_call:
# Ensure function.arguments is a string
if 'arguments' in tool_call['function'] and not isinstance(tool_call['function']['arguments'], str):
tool_call['function']['arguments'] = json.dumps(tool_call['function']['arguments'])
content = item['content']
content['message_id'] = item['message_id']
messages.append(content)
return messages
@ -327,6 +322,26 @@ Here are the XML tools available with examples:
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
if uncompressed_total_token_count > llm_max_tokens:
_i = 0 # Count the number of ToolResult messages
for msg in reversed(prepared_messages): # Start from the end and work backwards
if "content" in msg and msg['content'] and "ToolResult" in msg['content']: # Only compress ToolResult messages
_i += 1 # Count the number of ToolResult messages
msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
if msg_token_count > 5000: # If the message is too long
if _i > 1: # If this is not the most recent ToolResult message
message_id = msg.get('message_id') # Get the message_id
if message_id:
msg["content"] = msg["content"][:10000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message" # Truncate the message
else:
msg["content"] = msg["content"][:300000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it" # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise
compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages)
logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}") # Log the token compression for debugging later
# 5. Make LLM API call
logger.debug("Making LLM API call")
try: