Merge pull request #44 from kortix-ai/context-management-v1

Simple version of LLM context management (summarizing)
Adam Cohen Hillel 2025-04-17 00:44:47 +01:00 committed by GitHub
commit bc7662a484
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 381 additions and 12 deletions

View File

@ -0,0 +1,298 @@
"""
Context Management for AgentPress Threads.
This module handles token counting and thread summarization to keep
threads from exceeding the context window limits of LLM models.
"""
import json
from typing import List, Dict, Any, Optional
from litellm import token_counter, completion, completion_cost
from services.supabase import DBConnection
from services.llm import make_llm_api_call
from utils.logger import logger
# Constants for token management
DEFAULT_TOKEN_THRESHOLD = 120000  # 120k tokens threshold for summarization
SUMMARY_TARGET_TOKENS = 10000 # Target ~10k tokens for the summary message
RESERVE_TOKENS = 5000 # Reserve tokens for new messages
class ContextManager:
"""Manages thread context including token counting and summarization."""
def __init__(self, token_threshold: int = DEFAULT_TOKEN_THRESHOLD):
"""Initialize the ContextManager.
Args:
token_threshold: Token count threshold to trigger summarization
"""
self.db = DBConnection()
self.token_threshold = token_threshold
async def get_thread_token_count(self, thread_id: str) -> int:
"""Get the current token count for a thread using LiteLLM.
Args:
thread_id: ID of the thread to analyze
Returns:
The total token count for relevant messages in the thread
"""
logger.debug(f"Getting token count for thread {thread_id}")
try:
# Get messages for the thread
messages = await self.get_messages_for_summarization(thread_id)
if not messages:
logger.debug(f"No messages found for thread {thread_id}")
return 0
            # Use litellm's token_counter for accurate counting (gpt-4 is
            # used as a fixed default tokenizer here).
            # This is much more accurate than the SQL-based estimation.
            token_count = token_counter(model="gpt-4", messages=messages)
logger.info(f"Thread {thread_id} has {token_count} tokens (calculated with litellm)")
return token_count
except Exception as e:
logger.error(f"Error getting token count: {str(e)}")
return 0
async def get_messages_for_summarization(self, thread_id: str) -> List[Dict[str, Any]]:
"""Get all LLM messages from the thread that need to be summarized.
This gets messages after the most recent summary or all messages if
no summary exists. Unlike get_llm_messages, this includes ALL messages
since the last summary, even if we're generating a new summary.
Args:
thread_id: ID of the thread to get messages from
Returns:
List of message objects to summarize
"""
logger.debug(f"Getting messages for summarization for thread {thread_id}")
client = await self.db.client
try:
# Find the most recent summary message
summary_result = await client.table('messages').select('created_at') \
.eq('thread_id', thread_id) \
.eq('type', 'summary') \
.eq('is_llm_message', True) \
.order('created_at', desc=True) \
.limit(1) \
.execute()
# Get messages after the most recent summary or all messages if no summary
if summary_result.data and len(summary_result.data) > 0:
last_summary_time = summary_result.data[0]['created_at']
logger.debug(f"Found last summary at {last_summary_time}")
# Get all messages after the summary, but NOT including the summary itself
messages_result = await client.table('messages').select('*') \
.eq('thread_id', thread_id) \
.eq('is_llm_message', True) \
.gt('created_at', last_summary_time) \
.order('created_at') \
.execute()
else:
logger.debug("No previous summary found, getting all messages")
# Get all messages
messages_result = await client.table('messages').select('*') \
.eq('thread_id', thread_id) \
.eq('is_llm_message', True) \
.order('created_at') \
.execute()
# Parse the message content if needed
messages = []
for msg in messages_result.data:
# Skip existing summary messages - we don't want to summarize summaries
if msg.get('type') == 'summary':
logger.debug(f"Skipping summary message from {msg.get('created_at')}")
continue
# Parse content if it's a string
content = msg['content']
if isinstance(content, str):
try:
content = json.loads(content)
except json.JSONDecodeError:
pass # Keep as string if not valid JSON
                # Ensure we have the proper format for the LLM: wrap bare
                # content in a role/content dict, deriving the role from the
                # message type. Guard with isinstance so a plain-string
                # content doesn't trigger a substring check via `in`.
                if (not isinstance(content, dict) or 'role' not in content) and 'type' in msg:
                    role = msg['type']
                    if role in ('assistant', 'user', 'system', 'tool'):
                        content = {'role': role, 'content': content}
messages.append(content)
logger.info(f"Got {len(messages)} messages to summarize for thread {thread_id}")
return messages
except Exception as e:
logger.error(f"Error getting messages for summarization: {str(e)}", exc_info=True)
return []
async def create_summary(
self,
thread_id: str,
messages: List[Dict[str, Any]],
model: str = "gpt-3.5-turbo"
) -> Optional[Dict[str, Any]]:
"""Generate a summary of conversation messages.
Args:
thread_id: ID of the thread to summarize
messages: Messages to summarize
model: LLM model to use for summarization
Returns:
Summary message object or None if summarization failed
"""
if not messages:
logger.warning("No messages to summarize")
return None
logger.info(f"Creating summary for thread {thread_id} with {len(messages)} messages")
# Create system message with summarization instructions
system_message = {
"role": "system",
"content": f"""You are a specialized summarization assistant. Your task is to create a concise but comprehensive summary of the conversation history.
The summary should:
1. Preserve all key information including decisions, conclusions, and important context
2. Include any tools that were used and their results
3. Maintain chronological order of events
4. Be presented as a narrated list of key points with section headers
5. Include only factual information from the conversation (no new information)
6. Be concise but detailed enough that the conversation can continue with this summary as context
VERY IMPORTANT: This summary will replace older parts of the conversation in the LLM's context window, so ensure it contains ALL key information and LATEST STATE OF THE CONVERSATION - SO WE WILL KNOW HOW TO PICK UP WHERE WE LEFT OFF.
THE CONVERSATION HISTORY TO SUMMARIZE IS AS FOLLOWS:
===============================================================
==================== CONVERSATION HISTORY ====================
{messages}
==================== END OF CONVERSATION HISTORY ====================
===============================================================
"""
}
try:
# Call LLM to generate summary
response = await make_llm_api_call(
model_name=model,
messages=[system_message, {"role": "user", "content": "PLEASE PROVIDE THE SUMMARY NOW."}],
temperature=0,
max_tokens=SUMMARY_TARGET_TOKENS,
stream=False
)
if response and hasattr(response, 'choices') and response.choices:
summary_content = response.choices[0].message.content
# Track token usage
try:
token_count = token_counter(model=model, messages=[{"role": "user", "content": summary_content}])
cost = completion_cost(model=model, prompt="", completion=summary_content)
logger.info(f"Summary generated with {token_count} tokens at cost ${cost:.6f}")
except Exception as e:
logger.error(f"Error calculating token usage: {str(e)}")
# Format the summary message with clear beginning and end markers
formatted_summary = f"""
======== CONVERSATION HISTORY SUMMARY ========
{summary_content}
======== END OF SUMMARY ========
The above is a summary of the conversation history. The conversation continues below.
"""
            # Wrap the formatted summary in a user message object
summary_message = {
"role": "user",
"content": formatted_summary
}
return summary_message
else:
logger.error("Failed to generate summary: Invalid response")
return None
except Exception as e:
logger.error(f"Error creating summary: {str(e)}", exc_info=True)
return None
async def check_and_summarize_if_needed(
self,
thread_id: str,
add_message_callback,
model: str = "gpt-3.5-turbo",
force: bool = False
) -> bool:
"""Check if thread needs summarization and summarize if so.
Args:
thread_id: ID of the thread to check
add_message_callback: Callback to add the summary message to the thread
model: LLM model to use for summarization
force: Whether to force summarization regardless of token count
Returns:
True if summarization was performed, False otherwise
"""
try:
# Get token count using LiteLLM (accurate model-specific counting)
token_count = await self.get_thread_token_count(thread_id)
# If token count is below threshold and not forcing, no summarization needed
if token_count < self.token_threshold and not force:
logger.debug(f"Thread {thread_id} has {token_count} tokens, below threshold {self.token_threshold}")
return False
# Log reason for summarization
if force:
logger.info(f"Forced summarization of thread {thread_id} with {token_count} tokens")
else:
logger.info(f"Thread {thread_id} exceeds token threshold ({token_count} >= {self.token_threshold}), summarizing...")
# Get messages to summarize
messages = await self.get_messages_for_summarization(thread_id)
# If there are too few messages, don't summarize
if len(messages) < 3:
logger.info(f"Thread {thread_id} has too few messages ({len(messages)}) to summarize")
return False
# Create summary
summary = await self.create_summary(thread_id, messages, model)
if summary:
# Add summary message to thread
await add_message_callback(
thread_id=thread_id,
type="summary",
content=summary,
is_llm_message=True,
metadata={"token_count": token_count}
)
logger.info(f"Successfully added summary to thread {thread_id}")
return True
else:
logger.error(f"Failed to create summary for thread {thread_id}")
return False
except Exception as e:
logger.error(f"Error in check_and_summarize_if_needed: {str(e)}", exc_info=True)
return False
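
Taken on its own, the class above is already usable end to end. What follows is a minimal usage sketch, assuming the repo's Supabase connection is configured and that add_message is a coroutine accepting the keyword arguments used by check_and_summarize_if_needed; compact_if_needed is a hypothetical helper, and the threshold and model name are illustrative placeholders:

from agentpress.context_manager import ContextManager

async def compact_if_needed(thread_id: str, add_message) -> None:
    # A lower threshold than the 120k default, purely for illustration.
    ctx = ContextManager(token_threshold=50_000)

    # Inspect the current size of the thread's LLM context.
    count = await ctx.get_thread_token_count(thread_id)
    print(f"thread at {count}/{ctx.token_threshold} tokens")

    # Summarize only if the threshold is actually crossed (force=False).
    did_summarize = await ctx.check_and_summarize_if_needed(
        thread_id=thread_id,
        add_message_callback=add_message,
        model="gpt-3.5-turbo",
        force=False,
    )
    if did_summarize:
        print("summary row added; subsequent fetches start from it")

The summary lands as an ordinary messages row with type 'summary', which is exactly what the SQL changes further down key off.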

View File

@ -7,14 +7,15 @@ This module provides comprehensive conversation management, including:
 - Tool registration and execution
 - LLM interaction with streaming support
 - Error handling and cleanup
+- Context summarization to manage token limits
 """

 import json
-import uuid
-from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Tuple, Callable, Literal
+from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal
 from services.llm import make_llm_api_call
-from agentpress.tool import Tool, ToolResult
+from agentpress.tool import Tool
 from agentpress.tool_registry import ToolRegistry
+from agentpress.context_manager import ContextManager
 from agentpress.response_processor import (
     ResponseProcessor,
     ProcessorConfig
@ -34,13 +35,16 @@ class ThreadManager:
     """

     def __init__(self):
-        """Initialize ThreadManager."""
+        """Initialize ThreadManager.
+        """
         self.db = DBConnection()
         self.tool_registry = ToolRegistry()
         self.response_processor = ResponseProcessor(
             tool_registry=self.tool_registry,
             add_message_callback=self.add_message
         )
+
+        self.context_manager = ContextManager()

     def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
         """Add a tool to the ThreadManager."""
@ -88,6 +92,9 @@ class ThreadManager:
     async def get_llm_messages(self, thread_id: str) -> List[Dict[str, Any]]:
         """Get all messages for a thread.
+
+        This method uses the SQL function which handles context truncation
+        by considering summary messages.

         Args:
             thread_id: The ID of the thread to get messages for.
@ -220,7 +227,41 @@ Here are the XML tools available with examples:
         # 1. Get messages from thread for LLM call
         messages = await self.get_llm_messages(thread_id)

-        # 2. Prepare messages for LLM call + add temporary message if it exists
+        # 2. Check token count before proceeding
+        # Use litellm to count tokens in the messages
+        token_count = 0
+        try:
+            from litellm import token_counter
+            token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+            token_threshold = self.context_manager.token_threshold
+            logger.info(f"Thread {thread_id} token count: {token_count}/{token_threshold} ({(token_count/token_threshold)*100:.1f}%)")
+
+            # If we're over the threshold, summarize the thread
+            if token_count >= token_threshold:
+                logger.info(f"Thread token count ({token_count}) exceeds threshold ({token_threshold}), summarizing...")
+
+                # Create summary using context manager
+                summarized = await self.context_manager.check_and_summarize_if_needed(
+                    thread_id=thread_id,
+                    add_message_callback=self.add_message,
+                    model=llm_model,
+                    force=True  # Force summarization
+                )
+
+                if summarized:
+                    # If summarization was successful, get the updated messages
+                    # This will now include the summary message and only messages after it
+                    logger.info("Summarization complete, fetching updated messages with summary")
+                    messages = await self.get_llm_messages(thread_id)
+                    # Recount tokens after summarization
+                    new_token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+                    logger.info(f"After summarization: token count reduced from {token_count} to {new_token_count}")
+                else:
+                    logger.warning("Summarization failed or wasn't needed - proceeding with original messages")
+        except Exception as e:
+            logger.error(f"Error counting tokens or summarizing: {str(e)}")
+
+        # 3. Prepare messages for LLM call + add temporary message if it exists
         prepared_messages = [system_prompt]

         # Find the last user message index
@ -242,19 +283,19 @@ Here are the XML tools available with examples:
                 prepared_messages.append(temp_msg)
                 logger.debug("Added temporary message to the end of prepared messages")

-        # 3. Create or use processor config - this is now redundant since we handle it above
+        # 4. Create or use processor config - this is now redundant since we handle it above
         # but kept for consistency and clarity
         logger.debug(f"Processor config: XML={processor_config.xml_tool_calling}, Native={processor_config.native_tool_calling}, "
                      f"Execute tools={processor_config.execute_tools}, Strategy={processor_config.tool_execution_strategy}, "
                      f"XML limit={processor_config.max_xml_tool_calls}")

-        # 4. Prepare tools for LLM call
+        # 5. Prepare tools for LLM call
         openapi_tool_schemas = None
         if processor_config.native_tool_calling:
             openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
             logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")

-        # 5. Make LLM API call
+        # 6. Make LLM API call
         logger.debug("Making LLM API call")
         try:
             llm_response = await make_llm_api_call(
@ -272,7 +313,7 @@ Here are the XML tools available with examples:
             logger.error(f"Failed to make LLM API call: {str(e)}", exc_info=True)
             raise

-        # 6. Process LLM response using the ResponseProcessor
+        # 7. Process LLM response using the ResponseProcessor
         if stream:
             logger.debug("Processing streaming response")
             response_generator = self.response_processor.process_streaming_response(
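
Stripped of logging, the pre-flight check added above reduces to the shape below. This is a readability paraphrase, not code from the PR: ensure_context_fits is a hypothetical helper, and tm stands for a ThreadManager instance:

from litellm import token_counter

async def ensure_context_fits(tm, thread_id: str, llm_model: str,
                              system_prompt: dict, messages: list) -> list:
    # Count tokens for the exact payload about to be sent.
    threshold = tm.context_manager.token_threshold
    if token_counter(model=llm_model, messages=[system_prompt] + messages) < threshold:
        return messages  # under budget: send as-is

    # Over budget: persist a summary row, then refetch. The refetched
    # history starts at the new summary and includes only later messages.
    summarized = await tm.context_manager.check_and_summarize_if_needed(
        thread_id=thread_id,
        add_message_callback=tm.add_message,
        model=llm_model,
        force=True,  # threshold already checked here, so skip the re-check
    )
    return await tm.get_llm_messages(thread_id) if summarized else messages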

View File

@ -284,6 +284,8 @@ DECLARE
     messages_array JSONB := '[]'::JSONB;
     has_access BOOLEAN;
     current_role TEXT;
+    latest_summary_id UUID;
+    latest_summary_time TIMESTAMP WITH TIME ZONE;
 BEGIN
     -- Get current role
     SELECT current_user INTO current_role;
@ -306,19 +308,46 @@ BEGIN
         END IF;
     END IF;

+    -- Find the latest summary message if it exists
+    SELECT message_id, created_at
+    INTO latest_summary_id, latest_summary_time
+    FROM messages
+    WHERE thread_id = p_thread_id
+    AND type = 'summary'
+    AND is_llm_message = TRUE
+    ORDER BY created_at DESC
+    LIMIT 1;
+
+    -- Log whether a summary was found (helpful for debugging)
+    IF latest_summary_id IS NOT NULL THEN
+        RAISE NOTICE 'Found latest summary message: id=%, time=%', latest_summary_id, latest_summary_time;
+    ELSE
+        RAISE NOTICE 'No summary message found for thread %', p_thread_id;
+    END IF;
+
     -- Parse content if it's stored as a string and return proper JSON objects
     WITH parsed_messages AS (
         SELECT
+            message_id,
             CASE
                 WHEN jsonb_typeof(content) = 'string' THEN content::text::jsonb
                 ELSE content
             END AS parsed_content,
-            created_at
+            created_at,
+            type
         FROM messages
         WHERE thread_id = p_thread_id
         AND is_llm_message = TRUE
+        AND (
+            -- Include the latest summary and all messages after it,
+            -- or all messages if no summary exists
+            latest_summary_id IS NULL
+            OR message_id = latest_summary_id
+            OR created_at > latest_summary_time
+        )
+        ORDER BY created_at
     )
-    SELECT JSONB_AGG(parsed_content ORDER BY created_at)
+    SELECT JSONB_AGG(parsed_content)
     INTO messages_array
     FROM parsed_messages;
@ -331,5 +360,5 @@ BEGIN
 END;
 $$;

--- Grant execute permission on the function
+-- Grant execute permissions
 GRANT EXECUTE ON FUNCTION get_llm_formatted_messages TO authenticated, service_role;
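
Callers of the function are unaffected by this change: a single RPC still returns the whole LLM-visible history, now truncated at the latest summary. A sketch of that call through the async Supabase client used elsewhere in the repo (fetch_llm_messages is a hypothetical helper; the p_thread_id parameter name is taken from the WHERE clause above):

async def fetch_llm_messages(client, thread_id: str) -> list:
    # The SQL function decides where the context starts: at the newest
    # 'summary' row if one exists, otherwise at the first message.
    result = await client.rpc(
        "get_llm_formatted_messages", {"p_thread_id": thread_id}
    ).execute()
    return result.data or []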

View File

@ -393,6 +393,7 @@ export const getMessages = async (threadId: string): Promise<Message[]> => {
     .select('*')
     .eq('thread_id', threadId)
     .neq('type', 'cost')
+    .neq('type', 'summary')
     .order('created_at', { ascending: true });

   if (error) {
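
The frontend change above adds a second .neq filter so summary rows stay out of the chat UI: they are bookkeeping for the LLM, not something the user wrote. Keeping all sketches here in Python, the equivalent query through supabase-py would look roughly like this (fetch_ui_messages is a hypothetical helper; the table and filters mirror the TypeScript):

async def fetch_ui_messages(client, thread_id: str) -> list:
    # Mirror of the TypeScript query: hide internal 'cost' and 'summary'
    # rows from the rendered conversation.
    result = await (
        client.table("messages")
        .select("*")
        .eq("thread_id", thread_id)
        .neq("type", "cost")
        .neq("type", "summary")
        .order("created_at", desc=False)
        .execute()
    )
    return result.data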