mirror of https://github.com/kortix-ai/suna.git
Merge pull request #44 from kortix-ai/context-management-v1
Simple version of LLM context management (summarizing)
commit bc7662a484
@@ -0,0 +1,298 @@
"""
Context Management for AgentPress Threads.

This module handles token counting and thread summarization to prevent
reaching the context window limitations of LLM models.
"""

import json
from typing import List, Dict, Any, Optional

from litellm import token_counter, completion, completion_cost
from services.supabase import DBConnection
from services.llm import make_llm_api_call
from utils.logger import logger

# Constants for token management
DEFAULT_TOKEN_THRESHOLD = 120000  # 120k tokens threshold for summarization
SUMMARY_TARGET_TOKENS = 10000  # Target ~10k tokens for the summary message
RESERVE_TOKENS = 5000  # Reserve tokens for new messages


class ContextManager:
    """Manages thread context including token counting and summarization."""

    def __init__(self, token_threshold: int = DEFAULT_TOKEN_THRESHOLD):
        """Initialize the ContextManager.

        Args:
            token_threshold: Token count threshold to trigger summarization
        """
        self.db = DBConnection()
        self.token_threshold = token_threshold

    async def get_thread_token_count(self, thread_id: str) -> int:
        """Get the current token count for a thread using LiteLLM.

        Args:
            thread_id: ID of the thread to analyze

        Returns:
            The total token count for relevant messages in the thread
        """
        logger.debug(f"Getting token count for thread {thread_id}")

        try:
            # Get messages for the thread
            messages = await self.get_messages_for_summarization(thread_id)

            if not messages:
                logger.debug(f"No messages found for thread {thread_id}")
                return 0

            # Use litellm's token_counter for accurate model-specific counting
            # This is much more accurate than the SQL-based estimation
            token_count = token_counter(model="gpt-4", messages=messages)

            logger.info(f"Thread {thread_id} has {token_count} tokens (calculated with litellm)")
            return token_count

        except Exception as e:
            logger.error(f"Error getting token count: {str(e)}")
            return 0

    async def get_messages_for_summarization(self, thread_id: str) -> List[Dict[str, Any]]:
        """Get all LLM messages from the thread that need to be summarized.

        This gets messages after the most recent summary or all messages if
        no summary exists. Unlike get_llm_messages, this includes ALL messages
        since the last summary, even if we're generating a new summary.

        Args:
            thread_id: ID of the thread to get messages from

        Returns:
            List of message objects to summarize
        """
        logger.debug(f"Getting messages for summarization for thread {thread_id}")
        client = await self.db.client

        try:
            # Find the most recent summary message
            summary_result = await client.table('messages').select('created_at') \
                .eq('thread_id', thread_id) \
                .eq('type', 'summary') \
                .eq('is_llm_message', True) \
                .order('created_at', desc=True) \
                .limit(1) \
                .execute()

            # Get messages after the most recent summary or all messages if no summary
            if summary_result.data and len(summary_result.data) > 0:
                last_summary_time = summary_result.data[0]['created_at']
                logger.debug(f"Found last summary at {last_summary_time}")

                # Get all messages after the summary, but NOT including the summary itself
                messages_result = await client.table('messages').select('*') \
                    .eq('thread_id', thread_id) \
                    .eq('is_llm_message', True) \
                    .gt('created_at', last_summary_time) \
                    .order('created_at') \
                    .execute()
            else:
                logger.debug("No previous summary found, getting all messages")
                # Get all messages
                messages_result = await client.table('messages').select('*') \
                    .eq('thread_id', thread_id) \
                    .eq('is_llm_message', True) \
                    .order('created_at') \
                    .execute()

            # Parse the message content if needed
            messages = []
            for msg in messages_result.data:
                # Skip existing summary messages - we don't want to summarize summaries
                if msg.get('type') == 'summary':
                    logger.debug(f"Skipping summary message from {msg.get('created_at')}")
                    continue

                # Parse content if it's a string
                content = msg['content']
                if isinstance(content, str):
                    try:
                        content = json.loads(content)
                    except json.JSONDecodeError:
                        pass  # Keep as string if not valid JSON

                # Ensure we have the proper format for the LLM
                # (check isinstance first so plain-string content is wrapped too,
                # rather than accidentally substring-testing it for 'role')
                if (not isinstance(content, dict) or 'role' not in content) and 'type' in msg:
                    # Convert message type to role if needed
                    role = msg['type']
                    if role in ('assistant', 'user', 'system', 'tool'):
                        content = {'role': role, 'content': content}

                messages.append(content)

            logger.info(f"Got {len(messages)} messages to summarize for thread {thread_id}")
            return messages

        except Exception as e:
            logger.error(f"Error getting messages for summarization: {str(e)}", exc_info=True)
            return []

    async def create_summary(
        self,
        thread_id: str,
        messages: List[Dict[str, Any]],
        model: str = "gpt-3.5-turbo"
    ) -> Optional[Dict[str, Any]]:
        """Generate a summary of conversation messages.

        Args:
            thread_id: ID of the thread to summarize
            messages: Messages to summarize
            model: LLM model to use for summarization

        Returns:
            Summary message object or None if summarization failed
        """
        if not messages:
            logger.warning("No messages to summarize")
            return None

        logger.info(f"Creating summary for thread {thread_id} with {len(messages)} messages")

        # Create system message with summarization instructions
        system_message = {
            "role": "system",
            "content": f"""You are a specialized summarization assistant. Your task is to create a concise but comprehensive summary of the conversation history.

The summary should:
1. Preserve all key information including decisions, conclusions, and important context
2. Include any tools that were used and their results
3. Maintain chronological order of events
4. Be presented as a narrated list of key points with section headers
5. Include only factual information from the conversation (no new information)
6. Be concise but detailed enough that the conversation can continue with this summary as context

VERY IMPORTANT: This summary will replace older parts of the conversation in the LLM's context window, so ensure it contains ALL key information and LATEST STATE OF THE CONVERSATION - SO WE WILL KNOW HOW TO PICK UP WHERE WE LEFT OFF.


THE CONVERSATION HISTORY TO SUMMARIZE IS AS FOLLOWS:
===============================================================
==================== CONVERSATION HISTORY ====================
{messages}
==================== END OF CONVERSATION HISTORY ====================
===============================================================
"""
        }

        try:
            # Call LLM to generate summary
            response = await make_llm_api_call(
                model_name=model,
                messages=[system_message, {"role": "user", "content": "PLEASE PROVIDE THE SUMMARY NOW."}],
                temperature=0,
                max_tokens=SUMMARY_TARGET_TOKENS,
                stream=False
            )

            if response and hasattr(response, 'choices') and response.choices:
                summary_content = response.choices[0].message.content

                # Track token usage
                try:
                    token_count = token_counter(model=model, messages=[{"role": "user", "content": summary_content}])
                    cost = completion_cost(model=model, prompt="", completion=summary_content)
                    logger.info(f"Summary generated with {token_count} tokens at cost ${cost:.6f}")
                except Exception as e:
                    logger.error(f"Error calculating token usage: {str(e)}")

                # Format the summary message with clear beginning and end markers
                formatted_summary = f"""
======== CONVERSATION HISTORY SUMMARY ========

{summary_content}

======== END OF SUMMARY ========

The above is a summary of the conversation history. The conversation continues below.
"""

                # Format the summary message
                summary_message = {
                    "role": "user",
                    "content": formatted_summary
                }

                return summary_message
            else:
                logger.error("Failed to generate summary: Invalid response")
                return None

        except Exception as e:
            logger.error(f"Error creating summary: {str(e)}", exc_info=True)
            return None

    async def check_and_summarize_if_needed(
        self,
        thread_id: str,
        add_message_callback,
        model: str = "gpt-3.5-turbo",
        force: bool = False
    ) -> bool:
        """Check if thread needs summarization and summarize if so.

        Args:
            thread_id: ID of the thread to check
            add_message_callback: Callback to add the summary message to the thread
            model: LLM model to use for summarization
            force: Whether to force summarization regardless of token count

        Returns:
            True if summarization was performed, False otherwise
        """
        try:
            # Get token count using LiteLLM (accurate model-specific counting)
            token_count = await self.get_thread_token_count(thread_id)

            # If token count is below threshold and not forcing, no summarization needed
            if token_count < self.token_threshold and not force:
                logger.debug(f"Thread {thread_id} has {token_count} tokens, below threshold {self.token_threshold}")
                return False

            # Log reason for summarization
            if force:
                logger.info(f"Forced summarization of thread {thread_id} with {token_count} tokens")
            else:
                logger.info(f"Thread {thread_id} exceeds token threshold ({token_count} >= {self.token_threshold}), summarizing...")

            # Get messages to summarize
            messages = await self.get_messages_for_summarization(thread_id)

            # If there are too few messages, don't summarize
            if len(messages) < 3:
                logger.info(f"Thread {thread_id} has too few messages ({len(messages)}) to summarize")
                return False

            # Create summary
            summary = await self.create_summary(thread_id, messages, model)

            if summary:
                # Add summary message to thread
                await add_message_callback(
                    thread_id=thread_id,
                    type="summary",
                    content=summary,
                    is_llm_message=True,
                    metadata={"token_count": token_count}
                )

                logger.info(f"Successfully added summary to thread {thread_id}")
                return True
            else:
                logger.error(f"Failed to create summary for thread {thread_id}")
                return False

        except Exception as e:
            logger.error(f"Error in check_and_summarize_if_needed: {str(e)}", exc_info=True)
            return False
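
For orientation, a minimal sketch of driving this class directly. The add_message stub below is a hypothetical stand-in for ThreadManager.add_message (its keyword signature is inferred from the check_and_summarize_if_needed call above), and the thread id is illustrative:

    import asyncio

    from agentpress.context_manager import ContextManager

    async def add_message(thread_id, type, content, is_llm_message, metadata=None):
        # Hypothetical persistence callback; the real app writes the summary
        # to the messages table with type='summary'.
        print(f"[{thread_id}] storing {type} message")

    async def main():
        ctx = ContextManager()  # defaults to DEFAULT_TOKEN_THRESHOLD (120k)

        count = await ctx.get_thread_token_count("thread-123")
        print(f"current token count: {count}")

        # Summarizes only when the threshold is crossed; force=True skips the check.
        summarized = await ctx.check_and_summarize_if_needed(
            thread_id="thread-123",
            add_message_callback=add_message,
            model="gpt-3.5-turbo",
        )
        print("summary added" if summarized else "no summarization needed")

    asyncio.run(main())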
@@ -7,14 +7,15 @@ This module provides comprehensive conversation management, including:
- Tool registration and execution
- LLM interaction with streaming support
- Error handling and cleanup
- Context summarization to manage token limits
"""

import json
import uuid
-from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Tuple, Callable, Literal
+from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator, Literal
from services.llm import make_llm_api_call
-from agentpress.tool import Tool, ToolResult
+from agentpress.tool import Tool
from agentpress.tool_registry import ToolRegistry
+from agentpress.context_manager import ContextManager
from agentpress.response_processor import (
    ResponseProcessor,
    ProcessorConfig
@@ -34,13 +35,16 @@ class ThreadManager:
    """

    def __init__(self):
-        """Initialize ThreadManager."""
+        """Initialize ThreadManager.
+
+        """
        self.db = DBConnection()
        self.tool_registry = ToolRegistry()
        self.response_processor = ResponseProcessor(
            tool_registry=self.tool_registry,
            add_message_callback=self.add_message
        )
+        self.context_manager = ContextManager()

    def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
        """Add a tool to the ThreadManager."""
@@ -88,6 +92,9 @@
    async def get_llm_messages(self, thread_id: str) -> List[Dict[str, Any]]:
        """Get all messages for a thread.

+        This method uses the SQL function which handles context truncation
+        by considering summary messages.
+
        Args:
            thread_id: The ID of the thread to get messages for.

@@ -220,7 +227,41 @@ Here are the XML tools available with examples:
        # 1. Get messages from thread for LLM call
        messages = await self.get_llm_messages(thread_id)

-        # 2. Prepare messages for LLM call + add temporary message if it exists
+        # 2. Check token count before proceeding
+        # Use litellm to count tokens in the messages
+        token_count = 0
+        try:
+            from litellm import token_counter
+            token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+            token_threshold = self.context_manager.token_threshold
+            logger.info(f"Thread {thread_id} token count: {token_count}/{token_threshold} ({(token_count/token_threshold)*100:.1f}%)")
+
+            # If we're over the threshold, summarize the thread
+            if token_count >= token_threshold:
+                logger.info(f"Thread token count ({token_count}) exceeds threshold ({token_threshold}), summarizing...")
+
+                # Create summary using context manager
+                summarized = await self.context_manager.check_and_summarize_if_needed(
+                    thread_id=thread_id,
+                    add_message_callback=self.add_message,
+                    model=llm_model,
+                    force=True  # Force summarization
+                )
+
+                if summarized:
+                    # If summarization was successful, get the updated messages
+                    # This will now include the summary message and only messages after it
+                    logger.info("Summarization complete, fetching updated messages with summary")
+                    messages = await self.get_llm_messages(thread_id)
+                    # Recount tokens after summarization
+                    new_token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+                    logger.info(f"After summarization: token count reduced from {token_count} to {new_token_count}")
+                else:
+                    logger.warning("Summarization failed or wasn't needed - proceeding with original messages")
+        except Exception as e:
+            logger.error(f"Error counting tokens or summarizing: {str(e)}")
+
+        # 3. Prepare messages for LLM call + add temporary message if it exists
        prepared_messages = [system_prompt]

        # Find the last user message index
@@ -242,19 +283,19 @@ Here are the XML tools available with examples:
            prepared_messages.append(temp_msg)
            logger.debug("Added temporary message to the end of prepared messages")

-        # 3. Create or use processor config - this is now redundant since we handle it above
+        # 4. Create or use processor config - this is now redundant since we handle it above
        # but kept for consistency and clarity
        logger.debug(f"Processor config: XML={processor_config.xml_tool_calling}, Native={processor_config.native_tool_calling}, "
                     f"Execute tools={processor_config.execute_tools}, Strategy={processor_config.tool_execution_strategy}, "
                     f"XML limit={processor_config.max_xml_tool_calls}")

-        # 4. Prepare tools for LLM call
+        # 5. Prepare tools for LLM call
        openapi_tool_schemas = None
        if processor_config.native_tool_calling:
            openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
            logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")

-        # 5. Make LLM API call
+        # 6. Make LLM API call
        logger.debug("Making LLM API call")
        try:
            llm_response = await make_llm_api_call(
@@ -272,7 +313,7 @@ Here are the XML tools available with examples:
            logger.error(f"Failed to make LLM API call: {str(e)}", exc_info=True)
            raise

-        # 6. Process LLM response using the ResponseProcessor
+        # 7. Process LLM response using the ResponseProcessor
        if stream:
            logger.debug("Processing streaming response")
            response_generator = self.response_processor.process_streaming_response(
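
The threshold check that run_thread now performs reduces to a single litellm call; a standalone sketch (model name and messages are illustrative):

    from litellm import token_counter

    system_prompt = {"role": "system", "content": "You are a helpful assistant."}
    messages = [{"role": "user", "content": "Summarize our discussion so far."}]

    # Same counting call run_thread now makes before each LLM request.
    token_count = token_counter(model="gpt-4", messages=[system_prompt] + messages)
    token_threshold = 120000  # DEFAULT_TOKEN_THRESHOLD from context_manager.py
    print(f"token count: {token_count}/{token_threshold} "
          f"({(token_count / token_threshold) * 100:.1f}%)")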
@@ -284,6 +284,8 @@ DECLARE
    messages_array JSONB := '[]'::JSONB;
    has_access BOOLEAN;
    current_role TEXT;
+    latest_summary_id UUID;
+    latest_summary_time TIMESTAMP WITH TIME ZONE;
BEGIN
    -- Get current role
    SELECT current_user INTO current_role;
@@ -306,19 +308,46 @@ BEGIN
        END IF;
    END IF;

+    -- Find the latest summary message if it exists
+    SELECT message_id, created_at
+    INTO latest_summary_id, latest_summary_time
+    FROM messages
+    WHERE thread_id = p_thread_id
+    AND type = 'summary'
+    AND is_llm_message = TRUE
+    ORDER BY created_at DESC
+    LIMIT 1;
+
+    -- Log whether a summary was found (helpful for debugging)
+    IF latest_summary_id IS NOT NULL THEN
+        RAISE NOTICE 'Found latest summary message: id=%, time=%', latest_summary_id, latest_summary_time;
+    ELSE
+        RAISE NOTICE 'No summary message found for thread %', p_thread_id;
+    END IF;
+
    -- Parse content if it's stored as a string and return proper JSON objects
    WITH parsed_messages AS (
        SELECT
            message_id,
            CASE
                WHEN jsonb_typeof(content) = 'string' THEN content::text::jsonb
                ELSE content
            END AS parsed_content,
-            created_at
+            created_at,
+            type
        FROM messages
        WHERE thread_id = p_thread_id
        AND is_llm_message = TRUE
+        AND (
+            -- Include the latest summary and all messages after it,
+            -- or all messages if no summary exists
+            latest_summary_id IS NULL
+            OR message_id = latest_summary_id
+            OR created_at > latest_summary_time
+        )
        ORDER BY created_at
    )
-    SELECT JSONB_AGG(parsed_content ORDER BY created_at)
+    SELECT JSONB_AGG(parsed_content)
    INTO messages_array
    FROM parsed_messages;
@@ -331,5 +360,5 @@ BEGIN
    END;
$$;

--- Grant execute permission on the function
+-- Grant execute permissions
GRANT EXECUTE ON FUNCTION get_llm_formatted_messages TO authenticated, service_role;
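
From Python, the updated function is still invoked the same way via Supabase RPC; a minimal sketch, assuming the DBConnection wrapper used elsewhere in this commit and an illustrative thread id:

    from services.supabase import DBConnection

    async def fetch_llm_messages(thread_id: str):
        # Calls the SQL function above. When a summary exists, the result is
        # the latest summary message followed only by messages created after it.
        client = await DBConnection().client
        result = await client.rpc(
            "get_llm_formatted_messages",
            {"p_thread_id": thread_id},
        ).execute()
        return result.data or []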
@@ -393,6 +393,7 @@ export const getMessages = async (threadId: string): Promise<Message[]> => {
    .select('*')
    .eq('thread_id', threadId)
    .neq('type', 'cost')
+    .neq('type', 'summary')
    .order('created_at', { ascending: true });

  if (error) {