From 6a59b8e1106a5048d3a045b45750d46b2cc91e9e Mon Sep 17 00:00:00 2001
From: Adam Cohen Hillel
Date: Thu, 17 Apr 2025 00:16:53 +0100
Subject: [PATCH] Add context manager for token counting and thread summarization

---
 backend/agentpress/context_manager.py        | 298 ++++++++++++++++++
 backend/agentpress/thread_manager.py         |  59 +++-
 .../20250416133920_agentpress_schema.sql     |  35 +-
 3 files changed, 380 insertions(+), 12 deletions(-)
 create mode 100644 backend/agentpress/context_manager.py

diff --git a/backend/agentpress/context_manager.py b/backend/agentpress/context_manager.py
new file mode 100644
index 00000000..e88ce361
--- /dev/null
+++ b/backend/agentpress/context_manager.py
@@ -0,0 +1,298 @@
+"""
+Context Management for AgentPress Threads.
+
+This module handles token counting and thread summarization to prevent
+exceeding the context window limits of LLM models.
+"""
+
+import json
+from typing import List, Dict, Any, Optional
+
+from litellm import token_counter, completion_cost
+from services.supabase import DBConnection
+from services.llm import make_llm_api_call
+from utils.logger import logger
+
+# Constants for token management
+DEFAULT_TOKEN_THRESHOLD = 120000  # 120k-token threshold that triggers summarization
+SUMMARY_TARGET_TOKENS = 10000  # Target ~10k tokens for the summary message
+RESERVE_TOKENS = 5000  # Tokens reserved for new messages
+
+class ContextManager:
+    """Manages thread context, including token counting and summarization."""
+
+    def __init__(self, token_threshold: int = DEFAULT_TOKEN_THRESHOLD):
+        """Initialize the ContextManager.
+
+        Args:
+            token_threshold: Token count threshold to trigger summarization
+        """
+        self.db = DBConnection()
+        self.token_threshold = token_threshold
+
+    async def get_thread_token_count(self, thread_id: str) -> int:
+        """Get the current token count for a thread using LiteLLM.
+
+        Args:
+            thread_id: ID of the thread to analyze
+
+        Returns:
+            The total token count for relevant messages in the thread
+        """
+        logger.debug(f"Getting token count for thread {thread_id}")
+
+        try:
+            # Get messages for the thread
+            messages = await self.get_messages_for_summarization(thread_id)
+
+            if not messages:
+                logger.debug(f"No messages found for thread {thread_id}")
+                return 0
+
+            # Use litellm's token_counter for accurate, model-specific counting;
+            # this is much more accurate than the SQL-based estimation
+            token_count = token_counter(model="gpt-4", messages=messages)
+
+            logger.info(f"Thread {thread_id} has {token_count} tokens (calculated with litellm)")
+            return token_count
+
+        except Exception as e:
+            logger.error(f"Error getting token count: {str(e)}")
+            return 0
+
+    async def get_messages_for_summarization(self, thread_id: str) -> List[Dict[str, Any]]:
+        """Get all LLM messages from the thread that need to be summarized.
+
+        This gets messages after the most recent summary, or all messages if
+        no summary exists. Unlike get_llm_messages, this includes ALL messages
+        since the last summary, even if we're generating a new summary.
+
+        Args:
+            thread_id: ID of the thread to get messages from
+
+        Returns:
+            List of message objects to summarize
+        """
+        logger.debug(f"Getting messages for summarization for thread {thread_id}")
+        client = await self.db.client
+
+        try:
+            # Find the most recent summary message
+            summary_result = await client.table('messages').select('created_at') \
+                .eq('thread_id', thread_id) \
+                .eq('type', 'summary') \
+                .eq('is_llm_message', True) \
+                .order('created_at', desc=True) \
+                .limit(1) \
+                .execute()
+
+            # Get messages after the most recent summary, or all messages if no summary exists
+            if summary_result.data and len(summary_result.data) > 0:
+                last_summary_time = summary_result.data[0]['created_at']
+                logger.debug(f"Found last summary at {last_summary_time}")
+
+                # Get all messages after the summary, but NOT including the summary itself
+                messages_result = await client.table('messages').select('*') \
+                    .eq('thread_id', thread_id) \
+                    .eq('is_llm_message', True) \
+                    .gt('created_at', last_summary_time) \
+                    .order('created_at') \
+                    .execute()
+            else:
+                logger.debug("No previous summary found, getting all messages")
+                # Get all messages
+                messages_result = await client.table('messages').select('*') \
+                    .eq('thread_id', thread_id) \
+                    .eq('is_llm_message', True) \
+                    .order('created_at') \
+                    .execute()
+
+            # Parse the message content if needed
+            messages = []
+            for msg in messages_result.data:
+                # Skip existing summary messages - we don't want to summarize summaries
+                if msg.get('type') == 'summary':
+                    logger.debug(f"Skipping summary message from {msg.get('created_at')}")
+                    continue
+
+                # Parse content if it's a string
+                content = msg['content']
+                if isinstance(content, str):
+                    try:
+                        content = json.loads(content)
+                    except json.JSONDecodeError:
+                        pass  # Keep as string if not valid JSON
+
+                # Ensure we have the proper format for the LLM: wrap any content
+                # that doesn't already carry a role (note: checking 'role' in a
+                # raw string would be a substring test, not a key lookup)
+                if not (isinstance(content, dict) and 'role' in content) and 'type' in msg:
+                    # Convert the message type to a role if it maps to one
+                    role = msg['type']
+                    if role in ('assistant', 'user', 'system', 'tool'):
+                        content = {'role': role, 'content': content}
+
+                messages.append(content)
+
+            logger.info(f"Got {len(messages)} messages to summarize for thread {thread_id}")
+            return messages
+
+        except Exception as e:
+            logger.error(f"Error getting messages for summarization: {str(e)}", exc_info=True)
+            return []
+
+    async def create_summary(
+        self,
+        thread_id: str,
+        messages: List[Dict[str, Any]],
+        model: str = "gpt-3.5-turbo"
+    ) -> Optional[Dict[str, Any]]:
+        """Generate a summary of conversation messages.
+
+        Args:
+            thread_id: ID of the thread to summarize
+            messages: Messages to summarize
+            model: LLM model to use for summarization
+
+        Returns:
+            Summary message object, or None if summarization failed
+        """
+        if not messages:
+            logger.warning("No messages to summarize")
+            return None
+
+        logger.info(f"Creating summary for thread {thread_id} with {len(messages)} messages")
+
+        # Create system message with summarization instructions
+        system_message = {
+            "role": "system",
+            "content": f"""You are a specialized summarization assistant. Your task is to create a concise but comprehensive summary of the conversation history.
+
+The summary should:
+1. Preserve all key information, including decisions, conclusions, and important context
+2. Include any tools that were used and their results
+3. Maintain chronological order of events
+4. Be presented as a narrated list of key points with section headers
+5. Include only factual information from the conversation (no new information)
+6. Be concise but detailed enough that the conversation can continue with this summary as context
+
+VERY IMPORTANT: This summary will replace older parts of the conversation in the LLM's context window, so ensure it contains ALL key information and the LATEST STATE OF THE CONVERSATION - SO WE WILL KNOW HOW TO PICK UP WHERE WE LEFT OFF.
+
+
+THE CONVERSATION HISTORY TO SUMMARIZE IS AS FOLLOWS:
+===============================================================
+==================== CONVERSATION HISTORY ====================
+{messages}
+==================== END OF CONVERSATION HISTORY ====================
+===============================================================
+"""
+        }
+
+        try:
+            # Call LLM to generate summary
+            response = await make_llm_api_call(
+                model_name=model,
+                messages=[system_message, {"role": "user", "content": "PLEASE PROVIDE THE SUMMARY NOW."}],
+                temperature=0,
+                max_tokens=SUMMARY_TARGET_TOKENS,
+                stream=False
+            )
+
+            if response and hasattr(response, 'choices') and response.choices:
+                summary_content = response.choices[0].message.content
+
+                # Track token usage
+                try:
+                    token_count = token_counter(model=model, messages=[{"role": "user", "content": summary_content}])
+                    cost = completion_cost(model=model, prompt="", completion=summary_content)
+                    logger.info(f"Summary generated with {token_count} tokens at cost ${cost:.6f}")
+                except Exception as e:
+                    logger.error(f"Error calculating token usage: {str(e)}")
+
+                # Format the summary with clear beginning and end markers
+                formatted_summary = f"""
+======== CONVERSATION HISTORY SUMMARY ========
+
+{summary_content}
+
+======== END OF SUMMARY ========
+
+The above is a summary of the conversation history. The conversation continues below.
+"""
+
+                # Wrap the formatted summary in a message object
+                summary_message = {
+                    "role": "user",
+                    "content": formatted_summary
+                }
+
+                return summary_message
+            else:
+                logger.error("Failed to generate summary: Invalid response")
+                return None
+
+        except Exception as e:
+            logger.error(f"Error creating summary: {str(e)}", exc_info=True)
+            return None
+
+    async def check_and_summarize_if_needed(
+        self,
+        thread_id: str,
+        add_message_callback,
+        model: str = "gpt-3.5-turbo",
+        force: bool = False
+    ) -> bool:
+        """Check whether the thread needs summarization and summarize if so.
+
+        Args:
+            thread_id: ID of the thread to check
+            add_message_callback: Callback to add the summary message to the thread
+            model: LLM model to use for summarization
+            force: Whether to force summarization regardless of token count
+
+        Returns:
+            True if summarization was performed, False otherwise
+        """
+        try:
+            # Get token count using LiteLLM (accurate model-specific counting)
+            token_count = await self.get_thread_token_count(thread_id)
+
+            # If the token count is below the threshold and we're not forcing, skip
+            if token_count < self.token_threshold and not force:
+                logger.debug(f"Thread {thread_id} has {token_count} tokens, below threshold {self.token_threshold}")
+                return False
+
+            # Log the reason for summarization
+            if force:
+                logger.info(f"Forced summarization of thread {thread_id} with {token_count} tokens")
+            else:
+                logger.info(f"Thread {thread_id} exceeds token threshold ({token_count} >= {self.token_threshold}), summarizing...")
+
+            # Get messages to summarize
+            messages = await self.get_messages_for_summarization(thread_id)
+
+            # If there are too few messages, don't summarize
+            if len(messages) < 3:
+                logger.info(f"Thread {thread_id} has too few messages ({len(messages)}) to summarize")
+                return False
+
+            # Create summary
+            summary = await self.create_summary(thread_id, messages, model)
+
+            if summary:
+                # Add summary message to thread
+                await add_message_callback(
+                    thread_id=thread_id,
+                    type="summary",
+                    content=summary,
+                    is_llm_message=True,
+                    metadata={"token_count": token_count}
+                )
+
+                logger.info(f"Successfully added summary to thread {thread_id}")
+                return True
+            else:
+                logger.error(f"Failed to create summary for thread {thread_id}")
+                return False
+
+        except Exception as e:
+            logger.error(f"Error in check_and_summarize_if_needed: {str(e)}", exc_info=True)
+            return False
\ No newline at end of file
diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py
index 014a2b97..cec9d473 100644
--- a/backend/agentpress/thread_manager.py
+++ b/backend/agentpress/thread_manager.py
@@ -7,14 +7,15 @@ This module provides comprehensive conversation management, including:
 - Tool registration and execution
 - LLM interaction with streaming support
 - Error handling and cleanup
+- Context summarization to manage token limits
 """
 
 import json
-import uuid
-from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Tuple, Callable, Literal
+from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal
 from services.llm import make_llm_api_call
-from agentpress.tool import Tool, ToolResult
+from agentpress.tool import Tool
 from agentpress.tool_registry import ToolRegistry
+from agentpress.context_manager import ContextManager
 from agentpress.response_processor import (
     ResponseProcessor,
     ProcessorConfig
@@ -34,13 +35,16 @@ class ThreadManager:
     """
 
     def __init__(self):
-        """Initialize ThreadManager."""
+        """Initialize ThreadManager with a database connection, tool registry,
+        response processor, and context manager.
+        """
         self.db = DBConnection()
         self.tool_registry = ToolRegistry()
         self.response_processor = ResponseProcessor(
             tool_registry=self.tool_registry,
             add_message_callback=self.add_message
         )
+        self.context_manager = ContextManager()
 
     def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
         """Add a tool to the ThreadManager."""
@@ -88,6 +92,9 @@ class ThreadManager:
     async def get_llm_messages(self, thread_id: str) -> List[Dict[str, Any]]:
         """Get all messages for a thread.
 
+        This method uses the get_llm_formatted_messages SQL function, which
+        truncates context to the latest summary and the messages after it.
+
         Args:
             thread_id: The ID of the thread to get messages for.
 
@@ -220,7 +227,41 @@ Here are the XML tools available with examples:
             # 1. Get messages from thread for LLM call
             messages = await self.get_llm_messages(thread_id)
 
-            # 2. Prepare messages for LLM call + add temporary message if it exists
+            # 2. Check token count before proceeding
+            # Use litellm to count tokens in the messages
+            token_count = 0
+            try:
+                from litellm import token_counter
+                token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+                token_threshold = self.context_manager.token_threshold
+                logger.info(f"Thread {thread_id} token count: {token_count}/{token_threshold} ({(token_count/token_threshold)*100:.1f}%)")
+
+                # If we're over the threshold, summarize the thread
+                if token_count >= token_threshold:
+                    logger.info(f"Thread token count ({token_count}) exceeds threshold ({token_threshold}), summarizing...")
+
+                    # Create summary using the context manager
+                    summarized = await self.context_manager.check_and_summarize_if_needed(
+                        thread_id=thread_id,
+                        add_message_callback=self.add_message,
+                        model=llm_model,
+                        force=True  # Force summarization
+                    )
+
+                    if summarized:
+                        # If summarization succeeded, fetch the updated messages;
+                        # these now include the summary and only the messages after it
+                        logger.info("Summarization complete, fetching updated messages with summary")
+                        messages = await self.get_llm_messages(thread_id)
+                        # Recount tokens after summarization
+                        new_token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+                        logger.info(f"After summarization: token count reduced from {token_count} to {new_token_count}")
+                    else:
+                        logger.warning("Summarization failed or wasn't needed - proceeding with original messages")
+            except Exception as e:
+                logger.error(f"Error counting tokens or summarizing: {str(e)}")
+
+            # 3. Prepare messages for LLM call + add temporary message if it exists
             prepared_messages = [system_prompt]
 
             # Find the last user message index
@@ -242,19 +283,19 @@ Here are the XML tools available with examples:
                 prepared_messages.append(temp_msg)
                 logger.debug("Added temporary message to the end of prepared messages")
 
-            # 3. Create or use processor config - this is now redundant since we handle it above
+            # 4. Create or use processor config - this is now redundant since we handle it above
             # but kept for consistency and clarity
             logger.debug(f"Processor config: XML={processor_config.xml_tool_calling}, Native={processor_config.native_tool_calling}, "
                          f"Execute tools={processor_config.execute_tools}, Strategy={processor_config.tool_execution_strategy}, "
                          f"XML limit={processor_config.max_xml_tool_calls}")
 
-            # 4. Prepare tools for LLM call
+            # 5. Prepare tools for LLM call
             openapi_tool_schemas = None
             if processor_config.native_tool_calling:
                 openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
                 logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
 
-            # 5. Make LLM API call
+            # 6. Make LLM API call
             logger.debug("Making LLM API call")
             try:
                 llm_response = await make_llm_api_call(
@@ -272,7 +313,7 @@ Here are the XML tools available with examples:
                 logger.error(f"Failed to make LLM API call: {str(e)}", exc_info=True)
                 raise
 
-            # 6. Process LLM response using the ResponseProcessor
+            # 7. Process LLM response using the ResponseProcessor
             if stream:
                 logger.debug("Processing streaming response")
                 response_generator = self.response_processor.process_streaming_response(
diff --git a/backend/supabase/migrations/20250416133920_agentpress_schema.sql b/backend/supabase/migrations/20250416133920_agentpress_schema.sql
index 50ee43ba..a7d8702b 100644
--- a/backend/supabase/migrations/20250416133920_agentpress_schema.sql
+++ b/backend/supabase/migrations/20250416133920_agentpress_schema.sql
@@ -284,6 +284,8 @@ DECLARE
     messages_array JSONB := '[]'::JSONB;
     has_access BOOLEAN;
     current_role TEXT;
+    latest_summary_id UUID;
+    latest_summary_time TIMESTAMP WITH TIME ZONE;
 BEGIN
     -- Get current role
     SELECT current_user INTO current_role;
@@ -306,19 +308,46 @@ BEGIN
         END IF;
     END IF;
 
+    -- Find the latest summary message if it exists
+    SELECT message_id, created_at
+    INTO latest_summary_id, latest_summary_time
+    FROM messages
+    WHERE thread_id = p_thread_id
+    AND type = 'summary'
+    AND is_llm_message = TRUE
+    ORDER BY created_at DESC
+    LIMIT 1;
+
+    -- Log whether a summary was found (helpful for debugging)
+    IF latest_summary_id IS NOT NULL THEN
+        RAISE NOTICE 'Found latest summary message: id=%, time=%', latest_summary_id, latest_summary_time;
+    ELSE
+        RAISE NOTICE 'No summary message found for thread %', p_thread_id;
+    END IF;
+
     -- Parse content if it's stored as a string and return proper JSON objects
     WITH parsed_messages AS (
         SELECT
+            message_id,
             CASE
                 WHEN jsonb_typeof(content) = 'string' THEN content::text::jsonb
                 ELSE content
            END AS parsed_content,
-            created_at
+            created_at,
+            type
         FROM messages
         WHERE thread_id = p_thread_id
         AND is_llm_message = TRUE
+        AND (
+            -- Include the latest summary and all messages after it,
+            -- or all messages if no summary exists
+            latest_summary_id IS NULL
+            OR message_id = latest_summary_id
+            OR created_at > latest_summary_time
+        )
+        ORDER BY created_at
     )
-    SELECT JSONB_AGG(parsed_content ORDER BY created_at)
+    SELECT JSONB_AGG(parsed_content ORDER BY created_at) -- keep the aggregate-level ORDER BY; subquery order alone is not guaranteed to carry into JSONB_AGG
     INTO messages_array
     FROM parsed_messages;
 
@@ -331,5 +360,5 @@ BEGIN
 END;
 $$;
 
--- Grant execute permission on the function
+-- Grant execute permissions
 GRANT EXECUTE ON FUNCTION get_llm_formatted_messages TO authenticated, service_role;
\ No newline at end of file
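
Two short sketches follow; neither is part of the patch. First, the pre-call check in run_thread reduces to litellm's token_counter plus the ContextManager threshold. A minimal standalone sketch of that check, assuming litellm is installed; the model name and messages here are illustrative:

    from litellm import token_counter

    TOKEN_THRESHOLD = 120000  # mirrors DEFAULT_TOKEN_THRESHOLD in context_manager.py

    system_prompt = {"role": "system", "content": "You are a helpful assistant."}
    messages = [{"role": "user", "content": "Summarize the project plan so far."}]

    # Model-specific count over the full prompt, as run_thread computes it
    count = token_counter(model="gpt-4", messages=[system_prompt] + messages)
    print(f"{count}/{TOKEN_THRESHOLD} ({count / TOKEN_THRESHOLD * 100:.1f}%)")

    if count >= TOKEN_THRESHOLD:
        # At this point run_thread would force a summary and refetch the thread messages
        print("Over threshold: the thread would be summarized before the LLM call")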
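Second, an end-to-end sketch of the summarization path, assuming the Supabase and LLM services used above are configured and the thread already exists; the UUID is a placeholder:

    import asyncio

    from agentpress.thread_manager import ThreadManager

    async def main() -> None:
        manager = ThreadManager()
        thread_id = "00000000-0000-0000-0000-000000000000"  # placeholder thread UUID

        # Token count covers everything since the last summary (or all messages)
        tokens = await manager.context_manager.get_thread_token_count(thread_id)
        print(f"thread holds {tokens} tokens")

        # Force a summarization pass; the summary is persisted through
        # ThreadManager.add_message with type="summary", so the updated
        # get_llm_formatted_messages SQL function returns it plus only newer messages
        summarized = await manager.context_manager.check_and_summarize_if_needed(
            thread_id=thread_id,
            add_message_callback=manager.add_message,
            model="gpt-3.5-turbo",
            force=True,
        )
        print("summarized:", summarized)

    asyncio.run(main())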