"""
Context Management for AgentPress Threads .
This module handles token counting and thread summarization to prevent
reaching the context window limitations of LLM models .
"""
import json
from typing import List , Dict , Any , Optional
from litellm import token_counter , completion , completion_cost
from services . supabase import DBConnection
from services . llm import make_llm_api_call
from utils . logger import logger
# Constants for token management
DEFAULT_TOKEN_THRESHOLD = 120000 # Token count (~120k) that triggers summarization
SUMMARY_TARGET_TOKENS = 10000 # Target ~10k tokens for the summary message
RESERVE_TOKENS = 5000 # Reserve tokens for new messages
class ContextManager:
    """Manages thread context including token counting and summarization."""

    def __init__(self, token_threshold: int = DEFAULT_TOKEN_THRESHOLD):
        """Initialize the ContextManager.

        Args:
            token_threshold: Token count threshold to trigger summarization.
        """
        self.db = DBConnection()
        self.token_threshold = token_threshold

    async def get_thread_token_count(self, thread_id: str) -> int:
        """Get the current token count for a thread using LiteLLM.

        Args:
            thread_id: ID of the thread to analyze.

        Returns:
            The total token count for relevant messages in the thread,
            or 0 when the thread has no messages or counting fails.
        """
        logger.debug(f"Getting token count for thread {thread_id}")
        try:
            # Get messages for the thread
            messages = await self.get_messages_for_summarization(thread_id)
            if not messages:
                logger.debug(f"No messages found for thread {thread_id}")
                return 0

            # Use litellm's token_counter for accurate model-specific counting.
            # This is much more accurate than the SQL-based estimation.
            token_count = token_counter(model="gpt-4", messages=messages)
            logger.info(f"Thread {thread_id} has {token_count} tokens (calculated with litellm)")
            return token_count
        except Exception as e:
            logger.error(f"Error getting token count: {str(e)}")
            return 0

    async def get_messages_for_summarization(self, thread_id: str) -> List[Dict[str, Any]]:
        """Get all LLM messages from the thread that need to be summarized.

        This gets messages after the most recent summary, or all messages if
        no summary exists. Unlike get_llm_messages, this includes ALL messages
        since the last summary, even if we're generating a new summary.

        Args:
            thread_id: ID of the thread to get messages from.

        Returns:
            List of message objects to summarize (empty list on error).
        """
        logger.debug(f"Getting messages for summarization for thread {thread_id}")
        client = await self.db.client
        try:
            # Find the most recent summary message
            summary_result = await client.table('messages').select('created_at') \
                .eq('thread_id', thread_id) \
                .eq('type', 'summary') \
                .eq('is_llm_message', True) \
                .order('created_at', desc=True) \
                .limit(1) \
                .execute()

            # Get messages after the most recent summary or all messages if no summary
            if summary_result.data and len(summary_result.data) > 0:
                last_summary_time = summary_result.data[0]['created_at']
                logger.debug(f"Found last summary at {last_summary_time}")

                # Get all messages after the summary, but NOT including the summary itself
                messages_result = await client.table('messages').select('*') \
                    .eq('thread_id', thread_id) \
                    .eq('is_llm_message', True) \
                    .gt('created_at', last_summary_time) \
                    .order('created_at') \
                    .execute()
            else:
                logger.debug("No previous summary found, getting all messages")
                # Get all messages
                messages_result = await client.table('messages').select('*') \
                    .eq('thread_id', thread_id) \
                    .eq('is_llm_message', True) \
                    .order('created_at') \
                    .execute()

            # Parse the message content if needed
            messages = []
            for msg in messages_result.data:
                # Skip existing summary messages - we don't want to summarize summaries
                if msg.get('type') == 'summary':
                    logger.debug(f"Skipping summary message from {msg.get('created_at')}")
                    continue

                # Parse content if it's a string
                content = msg['content']
                if isinstance(content, str):
                    try:
                        content = json.loads(content)
                    except json.JSONDecodeError:
                        pass  # Keep as string if not valid JSON

                # Ensure we have the proper format for the LLM.
                # Guard with isinstance: json.loads can yield non-container
                # values (e.g. numbers) for which `'role' not in content`
                # would raise TypeError.
                if isinstance(content, (dict, str)) and 'role' not in content and 'type' in msg:
                    # Convert message type to role if needed
                    role = msg['type']
                    if role in ('assistant', 'user', 'system', 'tool'):
                        content = {'role': role, 'content': content}

                messages.append(content)

            logger.info(f"Got {len(messages)} messages to summarize for thread {thread_id}")
            return messages
        except Exception as e:
            logger.error(f"Error getting messages for summarization: {str(e)}", exc_info=True)
            return []

    async def create_summary(
        self,
        thread_id: str,
        messages: List[Dict[str, Any]],
        model: str = "gpt-4o-mini"
    ) -> Optional[Dict[str, Any]]:
        """Generate a summary of conversation messages.

        Args:
            thread_id: ID of the thread to summarize.
            messages: Messages to summarize.
            model: LLM model to use for summarization.

        Returns:
            Summary message object (role/content dict) or None if
            summarization failed.
        """
        if not messages:
            logger.warning("No messages to summarize")
            return None

        logger.info(f"Creating summary for thread {thread_id} with {len(messages)} messages")

        # Create system message with summarization instructions
        system_message = {
            "role": "system",
            "content": f"""You are a specialized summarization assistant. Your task is to create a concise but comprehensive summary of the conversation history.

The summary should:
1. Preserve all key information including decisions, conclusions, and important context
2. Include any tools that were used and their results
3. Maintain chronological order of events
4. Be presented as a narrated list of key points with section headers
5. Include only factual information from the conversation (no new information)
6. Be concise but detailed enough that the conversation can continue with this summary as context

VERY IMPORTANT: This summary will replace older parts of the conversation in the LLM's context window, so ensure it contains ALL key information and LATEST STATE OF THE CONVERSATION - SO WE WILL KNOW HOW TO PICK UP WHERE WE LEFT OFF.

THE CONVERSATION HISTORY TO SUMMARIZE IS AS FOLLOWS:
===============================================================
==================== CONVERSATION HISTORY ====================
{messages}
==================== END OF CONVERSATION HISTORY ====================
===============================================================
"""
        }

        try:
            # Call LLM to generate summary
            response = await make_llm_api_call(
                model_name=model,
                messages=[system_message, {"role": "user", "content": "PLEASE PROVIDE THE SUMMARY NOW."}],
                temperature=0,
                max_tokens=SUMMARY_TARGET_TOKENS,
                stream=False
            )

            if response and hasattr(response, 'choices') and response.choices:
                summary_content = response.choices[0].message.content

                # Track token usage (best-effort; failures only log)
                try:
                    token_count = token_counter(model=model, messages=[{"role": "user", "content": summary_content}])
                    cost = completion_cost(model=model, prompt="", completion=summary_content)
                    logger.info(f"Summary generated with {token_count} tokens at cost ${cost:.6f}")
                except Exception as e:
                    logger.error(f"Error calculating token usage: {str(e)}")

                # Format the summary message with clear beginning and end markers
                formatted_summary = f"""
======== CONVERSATION HISTORY SUMMARY ========

{summary_content}

======== END OF SUMMARY ========

The above is a summary of the conversation history. The conversation continues below.
"""

                # Format the summary message
                summary_message = {
                    "role": "user",
                    "content": formatted_summary
                }

                return summary_message
            else:
                logger.error("Failed to generate summary: Invalid response")
                return None
        except Exception as e:
            logger.error(f"Error creating summary: {str(e)}", exc_info=True)
            return None

    async def check_and_summarize_if_needed(
        self,
        thread_id: str,
        add_message_callback,
        model: str = "gpt-4o-mini",
        force: bool = False
    ) -> bool:
        """Check if thread needs summarization and summarize if so.

        Args:
            thread_id: ID of the thread to check.
            add_message_callback: Callback to add the summary message to the thread.
            model: LLM model to use for summarization.
            force: Whether to force summarization regardless of token count.

        Returns:
            True if summarization was performed, False otherwise.
        """
        try:
            # Get token count using LiteLLM (accurate model-specific counting)
            token_count = await self.get_thread_token_count(thread_id)

            # If token count is below threshold and not forcing, no summarization needed
            if token_count < self.token_threshold and not force:
                logger.debug(f"Thread {thread_id} has {token_count} tokens, below threshold {self.token_threshold}")
                return False

            # Log reason for summarization
            if force:
                logger.info(f"Forced summarization of thread {thread_id} with {token_count} tokens")
            else:
                logger.info(f"Thread {thread_id} exceeds token threshold ({token_count} >= {self.token_threshold}), summarizing...")

            # Get messages to summarize
            messages = await self.get_messages_for_summarization(thread_id)

            # If there are too few messages, don't summarize
            if len(messages) < 3:
                logger.info(f"Thread {thread_id} has too few messages ({len(messages)}) to summarize")
                return False

            # Create summary
            summary = await self.create_summary(thread_id, messages, model)

            if summary:
                # Add summary message to thread
                await add_message_callback(
                    thread_id=thread_id,
                    type="summary",
                    content=summary,
                    is_llm_message=True,
                    metadata={"token_count": token_count}
                )
                logger.info(f"Successfully added summary to thread {thread_id}")
                return True
            else:
                logger.error(f"Failed to create summary for thread {thread_id}")
                return False
        except Exception as e:
            logger.error(f"Error in check_and_summarize_if_needed: {str(e)}", exc_info=True)
            return False