From 960dd671f2913e275857ff3b40e7b5c1c6fedd11 Mon Sep 17 00:00:00 2001 From: sharath <29162020+tnfssc@users.noreply.github.com> Date: Thu, 29 May 2025 13:47:31 +0000 Subject: [PATCH] feat(context-compression): add ExpandMessageTool for message expansion and integrate into agent run process --- backend/agent/run.py | 2 + backend/agent/tools/expand_msg_tool.py | 91 ++++++++++++++++++++++++++ backend/agentpress/thread_manager.py | 43 ++++++++---- 3 files changed, 122 insertions(+), 14 deletions(-) create mode 100644 backend/agent/tools/expand_msg_tool.py diff --git a/backend/agent/run.py b/backend/agent/run.py index d101501f..791ae55c 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -18,6 +18,7 @@ from agent.tools.sb_shell_tool import SandboxShellTool from agent.tools.sb_files_tool import SandboxFilesTool from agent.tools.sb_browser_tool import SandboxBrowserTool from agent.tools.data_providers_tool import DataProvidersTool +from agent.tools.expand_msg_tool import ExpandMessageTool from agent.prompt import get_system_prompt from utils.logger import logger from utils.auth_utils import get_account_id_from_thread @@ -75,6 +76,7 @@ async def run_agent( thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool + thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) # Add data providers tool if RapidAPI key is available diff --git a/backend/agent/tools/expand_msg_tool.py b/backend/agent/tools/expand_msg_tool.py new file mode 100644 index 00000000..340f529e --- 
/dev/null +++ b/backend/agent/tools/expand_msg_tool.py @@ -0,0 +1,91 @@ +from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema +from agentpress.thread_manager import ThreadManager +import json + +class ExpandMessageTool(Tool): + """Tool for expanding a previous message to the user.""" + + def __init__(self, thread_id: str, thread_manager: ThreadManager): + super().__init__() + self.thread_manager = thread_manager + self.thread_id = thread_id + + @openapi_schema({ + "type": "function", + "function": { + "name": "expand_message", + "description": "Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation.", + "parameters": { + "type": "object", + "properties": { + "message_id": { + "type": "string", + "description": "The ID of the message to expand. Must be a UUID." + } + }, + "required": ["message_id"] + } + } + }) + @xml_schema( + tag_name="expand-message", + mappings=[ + {"param_name": "message_id", "node_type": "attribute", "path": "."} + ], + example=''' +Expand a message from the previous conversation with the user. Use this tool to expand a message that was truncated in the earlier conversation. The message_id must be a valid UUID. + + + + + + + + ''' + ) + async def expand_message(self, message_id: str) -> ToolResult: + """Expand a message from the previous conversation with the user. 
+ + Args: + message_id: The ID of the message to expand + + Returns: + ToolResult indicating the message was successfully expanded + """ + try: + client = await self.thread_manager.db.client + message = await client.table('messages').select('*').eq('message_id', message_id).eq('thread_id', self.thread_id).execute() + + if not message.data or len(message.data) == 0: + return self.fail_response(f"Message with ID {message_id} not found in thread {self.thread_id}") + + message_data = message.data[0] + message_content = message_data['content'] + final_content = message_content + if isinstance(message_content, dict) and 'content' in message_content: + final_content = message_content['content'] + elif isinstance(message_content, str): + try: + parsed_content = json.loads(message_content) + if isinstance(parsed_content, dict) and 'content' in parsed_content: + final_content = parsed_content['content'] + except json.JSONDecodeError: + pass + + return self.success_response({"status": "Message expanded successfully.", "message": final_content}) + except Exception as e: + return self.fail_response(f"Error expanding message: {str(e)}") + +if __name__ == "__main__": + import asyncio + + async def test_expand_message_tool(): + expand_message_tool = ExpandMessageTool(thread_id="<your-thread-uuid>", thread_manager=ThreadManager()) + + # Test expand message + expand_message_result = await expand_message_tool.expand_message( + message_id="004ab969-ef9a-4656-8aba-e392345227cd" + ) + print("Expand message result:", expand_message_result) + + asyncio.run(test_expand_message_tool()) diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index 15f76ff5..6dc07765 100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -119,7 +119,8 @@ class ThreadManager: client = await self.db.client try: - result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute() + # result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': 
thread_id}).execute() + result = await client.table('messages').select('*').eq('thread_id', thread_id).eq('is_llm_message', True).order('created_at').execute() # Parse the returned data which might be stringified JSON if not result.data: @@ -128,23 +129,17 @@ class ThreadManager: # Return properly parsed JSON objects messages = [] for item in result.data: - if isinstance(item, str): + if isinstance(item['content'], str): try: - parsed_item = json.loads(item) + parsed_item = json.loads(item['content']) + parsed_item['message_id'] = item['message_id'] messages.append(parsed_item) except json.JSONDecodeError: - logger.error(f"Failed to parse message: {item}") + logger.error(f"Failed to parse message: {item['content']}") else: - messages.append(item) - - # Ensure tool_calls have properly formatted function arguments - for message in messages: - if message.get('tool_calls'): - for tool_call in message['tool_calls']: - if isinstance(tool_call, dict) and 'function' in tool_call: - # Ensure function.arguments is a string - if 'arguments' in tool_call['function'] and not isinstance(tool_call['function']['arguments'], str): - tool_call['function']['arguments'] = json.dumps(tool_call['function']['arguments']) + content = item['content'] + content['message_id'] = item['message_id'] + messages.append(content) return messages @@ -327,6 +322,26 @@ Here are the XML tools available with examples: openapi_tool_schemas = self.tool_registry.get_openapi_schemas() logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas") + + uncompressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages) + + if uncompressed_total_token_count > llm_max_tokens: + _i = 0 # Count the number of ToolResult messages + for msg in reversed(prepared_messages): # Start from the end and work backwards + if "content" in msg and msg['content'] and "ToolResult" in msg['content']: # Only compress ToolResult messages + _i += 1 # Count the 
number of ToolResult messages + msg_token_count = token_counter(model=llm_model, messages=[msg]) # Count the number of tokens in the message + if msg_token_count > 5000: # If the message is too long + if _i > 1: # If this is not the most recent ToolResult message + message_id = msg.get('message_id') # Get the message_id + if message_id: + msg["content"] = msg["content"][:10000] + "... (truncated)" + f"\n\nThis message is too long, use the expand-message tool with message_id \"{message_id}\" to see the full message" # Truncate the message + else: + msg["content"] = msg["content"][:300000] + f"\n\nThis message is too long, repeat relevant information in your response to remember it" # Truncate to 300k characters to avoid overloading the context at once, but don't truncate otherwise + + compressed_total_token_count = token_counter(model=llm_model, messages=prepared_messages) + logger.info(f"token_compression: {uncompressed_total_token_count} -> {compressed_total_token_count}") # Log the token compression for debugging later + # 5. Make LLM API call logger.debug("Making LLM API call") try: