mirror of https://github.com/kortix-ai/suna.git
482 lines
21 KiB
Python
482 lines
21 KiB
Python
"""
|
|
Mathematically optimized prompt caching system for AgentPress.
|
|
|
|
Implements adaptive token-based caching with dynamic threshold calculation:
|
|
|
|
Mathematical Optimization:
|
|
- Auto-detects context window from model registry (200k-1M+ tokens)
|
|
- Calculates optimal cache thresholds using multi-factor formula
|
|
- Adapts to conversation stage, context size, and token density
|
|
- Prevents cache block preoccupation in large context windows
|
|
|
|
Dynamic Thresholds (scales with conversation length):
|
|
- 200k context: 1.5k (≤20 msgs) → 3k (≤100 msgs) → 5k (≤500 msgs) → 9k (500+ msgs)
|
|
- 1M context: 7.5k (≤20 msgs) → 15k (≤100 msgs) → 25k (≤500 msgs) → 45k (500+ msgs)
|
|
- 2M context: 15k (≤20 msgs) → 30k (≤100 msgs) → 50k (≤500 msgs) → 90k (500+ msgs)
|
|
- Adjusts for high/low token density conversations
|
|
- Enforces bounds: min 1024 tokens, max 15% of context
|
|
|
|
Technical Features:
|
|
- Accurate token counting using LiteLLM's model-specific tokenizers
|
|
- Strategic 4-block distribution with TTL optimization
|
|
- Fixed-size chunks prevent cache invalidation
|
|
- Cost-benefit analysis: 1.25x write cost vs 0.1x read savings
|
|
|
|
Cache Strategy:
|
|
1. Block 1: System prompt (1h TTL if ≥1024 tokens)
|
|
2. Blocks 2-4: Adaptive conversation chunks with mixed TTL
|
|
3. Early aggressive caching for quick wins
|
|
4. Late conservative caching to preserve blocks
|
|
|
|
Achieves 70-90% cost/latency savings while scaling efficiently
|
|
from 200k to 1M+ token context windows.
|
|
|
|
Based on Anthropic documentation and mathematical optimization (Sept 2025).
|
|
"""
|
|
|
|
from typing import Dict, Any, List, Optional
|
|
from core.utils.logger import logger
|
|
|
|
|
|
def get_resolved_model_id(model_name: str) -> str:
|
|
"""Resolve model name to its canonical ID through the model registry."""
|
|
try:
|
|
from core.ai_models.registry import ModelRegistry
|
|
registry = ModelRegistry()
|
|
model = registry.get(model_name)
|
|
if model:
|
|
resolved_id = model.id
|
|
if resolved_id != model_name:
|
|
logger.debug(f"Resolved model '{model_name}' to '{resolved_id}'")
|
|
return resolved_id
|
|
else:
|
|
logger.debug(f"Could not resolve model '{model_name}', using as-is")
|
|
return model_name
|
|
except Exception as e:
|
|
logger.warning(f"Error resolving model name: {e}")
|
|
return model_name
|
|
|
|
|
|
def is_anthropic_model(model_name: str) -> bool:
|
|
"""Check if model supports Anthropic prompt caching."""
|
|
resolved_model = get_resolved_model_id(model_name).lower()
|
|
return any(provider in resolved_model for provider in ['anthropic', 'claude', 'sonnet', 'haiku', 'opus'])
|
|
|
|
|
|
def estimate_token_count(text: str, model: str = "claude-3-5-sonnet-20240620") -> int:
|
|
"""
|
|
Accurate token counting using LiteLLM's token_counter.
|
|
Uses model-specific tokenizers when available, falls back to tiktoken.
|
|
"""
|
|
if not text:
|
|
return 0
|
|
|
|
try:
|
|
from litellm import token_counter
|
|
# Use LiteLLM's token counter with the specific model
|
|
return token_counter(model=model, text=str(text))
|
|
except Exception as e:
|
|
logger.warning(f"LiteLLM token counting failed: {e}, using fallback estimation")
|
|
# Fallback to word-based estimation
|
|
word_count = len(str(text).split())
|
|
return int(word_count * 1.3)
|
|
|
|
def get_message_token_count(message: Dict[str, Any], model: str = "claude-3-5-sonnet-20240620") -> int:
|
|
"""Get estimated token count for a message."""
|
|
content = message.get('content', '')
|
|
if isinstance(content, list):
|
|
total_tokens = 0
|
|
for item in content:
|
|
if isinstance(item, dict) and item.get('type') == 'text':
|
|
total_tokens += estimate_token_count(item.get('text', ''), model)
|
|
return total_tokens
|
|
return estimate_token_count(str(content), model)
|
|
|
|
def get_messages_token_count(messages: List[Dict[str, Any]], model: str = "claude-3-5-sonnet-20240620") -> int:
|
|
"""Get total token count for a list of messages."""
|
|
return sum(get_message_token_count(msg, model) for msg in messages)
|
|
|
|
|
|
def calculate_optimal_cache_threshold(
|
|
context_window: int,
|
|
message_count: int,
|
|
current_tokens: int
|
|
) -> int:
|
|
"""
|
|
Calculate mathematically optimized cache threshold based on:
|
|
1. Context window size (larger windows = larger thresholds)
|
|
2. Conversation stage (early vs late)
|
|
3. TTL considerations (5min vs 1h)
|
|
4. Cost-benefit analysis
|
|
|
|
Formula considerations:
|
|
- Early conversation: Lower thresholds for quick cache benefits
|
|
- Large context windows: Higher thresholds to avoid preoccupying blocks
|
|
- Cost efficiency: Balance 1.25x write cost vs 0.1x read savings
|
|
"""
|
|
|
|
# Base threshold as percentage of context window
|
|
# For 200k: 2.5% = 5k, For 1M: 2.5% = 25k
|
|
base_threshold = int(context_window * 0.025)
|
|
|
|
# Conversation stage factor - scaled for real-world thread lengths
|
|
if message_count <= 20:
|
|
# Early conversation: Aggressive caching for quick wins
|
|
stage_multiplier = 0.3 # 30% of base (1.5k for 200k, 7.5k for 1M)
|
|
elif message_count <= 100:
|
|
# Growing conversation: Balanced approach
|
|
stage_multiplier = 0.6 # 60% of base (3k for 200k, 15k for 1M)
|
|
elif message_count <= 500:
|
|
# Mature conversation: Larger chunks to preserve blocks
|
|
stage_multiplier = 1.0 # 100% of base (5k for 200k, 25k for 1M)
|
|
else:
|
|
# Very long conversation (500+ messages): Conservative to maximize efficiency
|
|
stage_multiplier = 1.8 # 180% of base (9k for 200k, 45k for 1M)
|
|
|
|
# Context window scaling
|
|
if context_window >= 2_000_000:
|
|
# Massive context (Gemini 2.5 Pro): Very large chunks
|
|
context_multiplier = 2.0
|
|
elif context_window >= 1_000_000:
|
|
# Very large context: Can afford larger chunks
|
|
context_multiplier = 1.5
|
|
elif context_window >= 500_000:
|
|
# Large context: Moderate scaling
|
|
context_multiplier = 1.2
|
|
else:
|
|
# Standard context: Conservative
|
|
context_multiplier = 1.0
|
|
|
|
# Current token density adjustment
|
|
if current_tokens > 0:
|
|
avg_tokens_per_message = current_tokens / message_count
|
|
if avg_tokens_per_message > 1000:
|
|
# High token density: Increase threshold to avoid micro-chunks
|
|
density_multiplier = 1.3
|
|
elif avg_tokens_per_message < 200:
|
|
# Low token density: Decrease threshold for more granular caching
|
|
density_multiplier = 0.8
|
|
else:
|
|
density_multiplier = 1.0
|
|
else:
|
|
density_multiplier = 1.0
|
|
|
|
# Calculate final threshold
|
|
optimal_threshold = int(base_threshold * stage_multiplier * context_multiplier * density_multiplier)
|
|
|
|
# Enforce bounds
|
|
min_threshold = max(1024, int(context_window * 0.005)) # At least 1024 tokens or 0.5% of context
|
|
max_threshold = int(context_window * 0.15) # No more than 15% of context window
|
|
|
|
final_threshold = max(min_threshold, min(optimal_threshold, max_threshold))
|
|
|
|
from core.utils.logger import logger
|
|
logger.info(f"🧮 Calculated optimal cache threshold: {final_threshold} tokens")
|
|
logger.debug(f" Context: {context_window}, Messages: {message_count}, Current: {current_tokens}")
|
|
logger.debug(f" Factors - Stage: {stage_multiplier:.1f}, Context: {context_multiplier:.1f}, Density: {density_multiplier:.1f}")
|
|
|
|
return final_threshold
|
|
|
|
|
|
def add_cache_control(message: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Add cache_control to a message."""
|
|
content = message.get('content', '')
|
|
role = message.get('role', '')
|
|
|
|
# If already in list format with cache_control, return as-is
|
|
if isinstance(content, list):
|
|
if content and isinstance(content[0], dict) and 'cache_control' in content[0]:
|
|
return message
|
|
# Convert existing list format to cached format
|
|
text_content = ""
|
|
for item in content:
|
|
if isinstance(item, dict) and item.get('type') == 'text':
|
|
text_content += item.get('text', '')
|
|
content = text_content
|
|
|
|
return {
|
|
"role": role,
|
|
"content": [
|
|
{
|
|
"type": "text",
|
|
"text": str(content),
|
|
"cache_control": {"type": "ephemeral"}
|
|
}
|
|
]
|
|
}
|
|
|
|
|
|
def apply_anthropic_caching_strategy(
|
|
working_system_prompt: Dict[str, Any],
|
|
conversation_messages: List[Dict[str, Any]],
|
|
model_name: str,
|
|
context_window_tokens: Optional[int] = None, # Auto-detect from model registry
|
|
cache_threshold_tokens: Optional[int] = None # Auto-calculate based on context window
|
|
) -> List[Dict[str, Any]]:
|
|
"""
|
|
Apply mathematically optimized token-based caching strategy for Anthropic models.
|
|
|
|
Dynamic Strategy:
|
|
- Auto-detects context window from model registry (200k-1M+ tokens)
|
|
- Calculates optimal cache thresholds based on conversation stage & context size
|
|
- Early conversations: Aggressive caching (2k-10k tokens) for quick wins
|
|
- Late conversations: Conservative caching (6k-30k tokens) to preserve blocks
|
|
- Adapts to token density (high/low verbosity conversations)
|
|
|
|
Mathematical Factors:
|
|
- Base threshold: 2.5% of context window
|
|
- Stage multiplier: 0.3x (≤20 msgs) → 0.6x (≤100 msgs) → 1.0x (≤500 msgs) → 1.8x (500+ msgs)
|
|
- Context multiplier: 1.0x (200k) → 1.2x (500k) → 1.5x (1M+) → 2.0x (2M+)
|
|
- Density multiplier: 0.8x (sparse) → 1.0x (normal) → 1.3x (dense)
|
|
|
|
This prevents cache invalidation while optimizing for context window utilization
|
|
and cost efficiency across different conversation patterns.
|
|
"""
|
|
if not conversation_messages:
|
|
conversation_messages = []
|
|
|
|
# Return early for non-Anthropic models
|
|
if not is_anthropic_model(model_name):
|
|
logger.debug(f"Model {model_name} doesn't support Anthropic caching")
|
|
# Filter out system messages to prevent duplication
|
|
filtered_conversation = [msg for msg in conversation_messages if msg.get('role') != 'system']
|
|
if len(filtered_conversation) < len(conversation_messages):
|
|
logger.debug(f"🔧 Filtered out {len(conversation_messages) - len(filtered_conversation)} system messages")
|
|
return [working_system_prompt] + filtered_conversation
|
|
|
|
# Get context window from model registry
|
|
if context_window_tokens is None:
|
|
try:
|
|
from core.ai_models.registry import registry
|
|
context_window_tokens = registry.get_context_window(model_name, default=200_000)
|
|
logger.debug(f"Retrieved context window from registry: {context_window_tokens} tokens")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to get context window from registry: {e}")
|
|
context_window_tokens = 200_000 # Safe default
|
|
|
|
# Calculate mathematically optimized cache threshold
|
|
if cache_threshold_tokens is None:
|
|
cache_threshold_tokens = calculate_optimal_cache_threshold(
|
|
context_window_tokens,
|
|
len(conversation_messages),
|
|
get_messages_token_count(conversation_messages, model_name) if conversation_messages else 0
|
|
)
|
|
|
|
logger.info(f"📊 Applying single cache breakpoint strategy for {len(conversation_messages)} messages")
|
|
|
|
# Filter out any existing system messages from conversation
|
|
system_msgs_in_conversation = [msg for msg in conversation_messages if msg.get('role') == 'system']
|
|
if system_msgs_in_conversation:
|
|
original_count = len(conversation_messages)
|
|
conversation_messages = [msg for msg in conversation_messages if msg.get('role') != 'system']
|
|
logger.info(f"🔧 Filtered out {original_count - len(conversation_messages)} system messages to prevent duplication")
|
|
|
|
prepared_messages = []
|
|
|
|
# Block 1: System prompt (cache if ≥1024 tokens)
|
|
system_tokens = get_message_token_count(working_system_prompt, model_name)
|
|
if system_tokens >= 1024: # Anthropic's minimum cacheable size
|
|
cached_system = add_cache_control(working_system_prompt)
|
|
prepared_messages.append(cached_system)
|
|
logger.info(f"🔥 Block 1: Cached system prompt ({system_tokens} tokens, 1h TTL)")
|
|
blocks_used = 1
|
|
else:
|
|
prepared_messages.append(working_system_prompt)
|
|
logger.debug(f"System prompt too small for caching: {system_tokens} tokens")
|
|
blocks_used = 0
|
|
|
|
# Handle conversation messages with token-based chunked caching
|
|
if not conversation_messages:
|
|
logger.debug("No conversation messages to add")
|
|
return prepared_messages
|
|
|
|
total_conversation_tokens = get_messages_token_count(conversation_messages, model_name)
|
|
logger.info(f"📊 Processing {len(conversation_messages)} messages ({total_conversation_tokens} tokens)")
|
|
|
|
# Check if we have enough tokens to start caching
|
|
if total_conversation_tokens < 1024: # Below minimum cacheable size
|
|
prepared_messages.extend(conversation_messages)
|
|
logger.debug(f"Conversation too small for caching: {total_conversation_tokens} tokens")
|
|
return prepared_messages
|
|
|
|
# Token-based chunked caching strategy
|
|
max_conversation_blocks = 4 - blocks_used # Reserve blocks used by system prompt
|
|
|
|
# Calculate optimal chunk size to avoid context overflow
|
|
# Reserve ~20% of context window for new messages and outputs
|
|
max_cacheable_tokens = int(context_window_tokens * 0.8)
|
|
|
|
if total_conversation_tokens <= max_cacheable_tokens:
|
|
# Conversation fits within cache limits - use chunked approach
|
|
chunks_created = create_conversation_chunks(
|
|
conversation_messages,
|
|
cache_threshold_tokens,
|
|
max_conversation_blocks,
|
|
prepared_messages,
|
|
model_name
|
|
)
|
|
blocks_used += chunks_created
|
|
logger.info(f"✅ Created {chunks_created} conversation cache blocks")
|
|
else:
|
|
# Conversation too large - need summarization or truncation
|
|
logger.warning(f"Conversation ({total_conversation_tokens} tokens) exceeds cache limit ({max_cacheable_tokens})")
|
|
# For now, add recent messages only (could implement summarization here)
|
|
recent_token_limit = min(cache_threshold_tokens * 2, max_cacheable_tokens)
|
|
recent_messages = get_recent_messages_within_token_limit(conversation_messages, recent_token_limit, model_name)
|
|
prepared_messages.extend(recent_messages)
|
|
logger.info(f"Added {len(recent_messages)} recent messages ({get_messages_token_count(recent_messages, model_name)} tokens)")
|
|
|
|
logger.info(f"🎯 Total cache blocks used: {blocks_used}/4")
|
|
|
|
# Log final structure
|
|
cache_count = sum(1 for msg in prepared_messages
|
|
if isinstance(msg.get('content'), list) and
|
|
msg['content'] and
|
|
isinstance(msg['content'][0], dict) and
|
|
'cache_control' in msg['content'][0])
|
|
|
|
logger.info(f"✅ Final structure: {cache_count} cache breakpoints, {len(prepared_messages)} total blocks")
|
|
return prepared_messages
|
|
|
|
|
|
def create_conversation_chunks(
|
|
messages: List[Dict[str, Any]],
|
|
chunk_threshold_tokens: int,
|
|
max_blocks: int,
|
|
prepared_messages: List[Dict[str, Any]],
|
|
model: str = "claude-3-5-sonnet-20240620"
|
|
) -> int:
|
|
"""
|
|
Create conversation cache chunks based on token thresholds.
|
|
Returns number of cache blocks created.
|
|
"""
|
|
if not messages or max_blocks <= 0:
|
|
return 0
|
|
|
|
chunks_created = 0
|
|
current_chunk = []
|
|
current_chunk_tokens = 0
|
|
|
|
for i, message in enumerate(messages):
|
|
message_tokens = get_message_token_count(message, model)
|
|
|
|
# Check if adding this message would exceed threshold
|
|
if current_chunk_tokens + message_tokens > chunk_threshold_tokens and current_chunk:
|
|
# Create cache block for current chunk
|
|
if chunks_created < max_blocks - 1: # Reserve last block for final message
|
|
chunk_text = format_conversation_for_cache(current_chunk)
|
|
cache_block = {
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "text",
|
|
"text": f"[Conversation Chunk {chunks_created + 1}]\n{chunk_text}",
|
|
"cache_control": {"type": "ephemeral", "ttl": "1h" if chunks_created == 0 else "5m"}
|
|
}
|
|
]
|
|
}
|
|
prepared_messages.append(cache_block)
|
|
chunks_created += 1
|
|
logger.info(f"🔥 Block {chunks_created + 1}: Cached chunk ({current_chunk_tokens} tokens, {len(current_chunk)} messages)")
|
|
|
|
# Reset for next chunk
|
|
current_chunk = []
|
|
current_chunk_tokens = 0
|
|
else:
|
|
# Hit max blocks - add remaining messages individually
|
|
prepared_messages.extend(current_chunk)
|
|
prepared_messages.extend(messages[i:])
|
|
logger.debug(f"Hit max blocks limit, added {len(messages) - i + len(current_chunk)} remaining messages uncached")
|
|
return chunks_created
|
|
|
|
current_chunk.append(message)
|
|
current_chunk_tokens += message_tokens
|
|
|
|
# Handle final chunk
|
|
if current_chunk:
|
|
if chunks_created < max_blocks and current_chunk_tokens >= 1024:
|
|
# Cache the final chunk
|
|
final_message = current_chunk[-1]
|
|
if len(current_chunk) > 1:
|
|
prepared_messages.extend(current_chunk[:-1])
|
|
|
|
cached_final = add_cache_control(final_message)
|
|
prepared_messages.append(cached_final)
|
|
chunks_created += 1
|
|
logger.info(f"🎯 Block {chunks_created + 1}: Cached final message ({get_message_token_count(final_message, model)} tokens)")
|
|
else:
|
|
# Add final chunk uncached
|
|
prepared_messages.extend(current_chunk)
|
|
logger.debug(f"Added final chunk uncached ({current_chunk_tokens} tokens, {len(current_chunk)} messages)")
|
|
|
|
return chunks_created
|
|
|
|
def get_recent_messages_within_token_limit(messages: List[Dict[str, Any]], token_limit: int, model: str = "claude-3-5-sonnet-20240620") -> List[Dict[str, Any]]:
|
|
"""Get the most recent messages that fit within the token limit."""
|
|
if not messages:
|
|
return []
|
|
|
|
recent_messages = []
|
|
total_tokens = 0
|
|
|
|
# Start from the end and work backwards
|
|
for message in reversed(messages):
|
|
message_tokens = get_message_token_count(message, model)
|
|
if total_tokens + message_tokens <= token_limit:
|
|
recent_messages.insert(0, message) # Insert at beginning to maintain order
|
|
total_tokens += message_tokens
|
|
else:
|
|
break
|
|
|
|
return recent_messages
|
|
|
|
def format_conversation_for_cache(messages: List[Dict[str, Any]]) -> str:
|
|
"""Format conversation messages into a single text block for caching."""
|
|
formatted_parts = []
|
|
|
|
for msg in messages:
|
|
role = msg.get('role', 'unknown')
|
|
content = msg.get('content', '')
|
|
|
|
# Handle different content formats
|
|
if isinstance(content, list):
|
|
# Extract text from list format
|
|
text_content = ""
|
|
for item in content:
|
|
if isinstance(item, dict) and item.get('type') == 'text':
|
|
text_content += item.get('text', '')
|
|
elif not isinstance(item, dict) or 'cache_control' not in item:
|
|
text_content += str(item)
|
|
else:
|
|
text_content = str(content)
|
|
|
|
# Clean up and format
|
|
text_content = text_content.strip()
|
|
if text_content:
|
|
role_indicator = "User" if role == "user" else "Assistant" if role == "assistant" else role.title()
|
|
formatted_parts.append(f"{role_indicator}: {text_content}")
|
|
|
|
return "\n\n".join(formatted_parts)
|
|
|
|
|
|
def validate_cache_blocks(messages: List[Dict[str, Any]], model_name: str, max_blocks: int = 4) -> List[Dict[str, Any]]:
|
|
"""
|
|
Validate cache block count stays within Anthropic's 4-block limit.
|
|
With our 2-block strategy, this should never be an issue.
|
|
"""
|
|
if not is_anthropic_model(model_name):
|
|
return messages
|
|
|
|
cache_count = sum(1 for msg in messages
|
|
if isinstance(msg.get('content'), list) and
|
|
msg['content'] and
|
|
isinstance(msg['content'][0], dict) and
|
|
'cache_control' in msg['content'][0])
|
|
|
|
if cache_count <= max_blocks:
|
|
logger.debug(f"✅ Cache validation passed: {cache_count}/{max_blocks} blocks")
|
|
return messages
|
|
|
|
logger.warning(f"⚠️ Cache validation failed: {cache_count}/{max_blocks} blocks")
|
|
return messages # With 2-block strategy, this shouldn't happen
|